From 9c4f96a7ebcecbef62e18b35520a8bde8c663966 Mon Sep 17 00:00:00 2001 From: hutiquan Date: Tue, 1 Jul 2025 13:43:44 +0800 Subject: [PATCH 1/7] feat:add support for mongo driver v2. --- drivers/mongov2/context.go | 48 +++++++++++ drivers/mongov2/contract.go | 23 ++++++ drivers/mongov2/example_test.go | 105 +++++++++++++++++++++++++ drivers/mongov2/factory.go | 16 ++++ drivers/mongov2/go.mod | 27 +++++++ drivers/mongov2/goroutine_leak_test.go | 14 ++++ drivers/mongov2/settings.go | 101 ++++++++++++++++++++++++ drivers/mongov2/settings_test.go | 102 ++++++++++++++++++++++++ drivers/mongov2/transaction.go | 86 ++++++++++++++++++++ go.work | 1 + go.work.sum | 8 ++ 11 files changed, 531 insertions(+) create mode 100644 drivers/mongov2/context.go create mode 100644 drivers/mongov2/contract.go create mode 100644 drivers/mongov2/example_test.go create mode 100644 drivers/mongov2/factory.go create mode 100644 drivers/mongov2/go.mod create mode 100644 drivers/mongov2/goroutine_leak_test.go create mode 100644 drivers/mongov2/settings.go create mode 100644 drivers/mongov2/settings_test.go create mode 100644 drivers/mongov2/transaction.go diff --git a/drivers/mongov2/context.go b/drivers/mongov2/context.go new file mode 100644 index 0000000..125886c --- /dev/null +++ b/drivers/mongov2/context.go @@ -0,0 +1,48 @@ +package mongov2 + +import ( + "context" + "github.com/avito-tech/go-transaction-manager/trm/v2" + "go.mongodb.org/mongo-driver/v2/mongo" + + trmcontext "github.com/avito-tech/go-transaction-manager/trm/v2/context" +) + +// DefaultCtxGetter is the CtxGetter with settings.DefaultCtxKey. +var DefaultCtxGetter = NewCtxGetter(trmcontext.DefaultManager) + +// CtxGetter gets Tr from trm.СtxManager by casting trm.Transaction to Tr. +type CtxGetter struct { + ctxManager trm.СtxManager +} + +// NewCtxGetter returns *CtxGetter to get Tr from context.Context. +func NewCtxGetter(c trm.СtxManager) *CtxGetter { + return &CtxGetter{ctxManager: c} +} + +// DefaultTrOrDB returns mongo.Session from context.Context or DB(mongo.Session) otherwise. +func (c *CtxGetter) DefaultTrOrDB(ctx context.Context, db *mongo.Session) *mongo.Session { + if tr := c.ctxManager.Default(ctx); tr != nil { + return c.convert(tr) + } + + return db +} + +// TrOrDB returns mongo.Session from context.Context by trm.CtxKey or DB(mongo.Session) otherwise. +func (c *CtxGetter) TrOrDB(ctx context.Context, key trm.CtxKey, db *mongo.Session) *mongo.Session { + if tr := c.ctxManager.ByKey(ctx, key); tr != nil { + return c.convert(tr) + } + + return db +} + +func (c *CtxGetter) convert(tr trm.Transaction) *mongo.Session { + if tx, ok := tr.Transaction().(*mongo.Session); ok { + return tx + } + + return nil +} diff --git a/drivers/mongov2/contract.go b/drivers/mongov2/contract.go new file mode 100644 index 0000000..0a51737 --- /dev/null +++ b/drivers/mongov2/contract.go @@ -0,0 +1,23 @@ +package mongov2 + +import ( + "context" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/mongo/readpref" + + "go.mongodb.org/mongo-driver/v2/mongo" +) + +//nolint:interfacebloat +type client interface { + Disconnect(ctx context.Context) error + Ping(ctx context.Context, rp *readpref.ReadPref) error + StartSession(opts ...options.Lister[options.SessionOptions]) (*mongo.Session, error) + Database(name string, opts ...options.Lister[options.DatabaseOptions]) *mongo.Database + ListDatabases(ctx context.Context, filter interface{}, opts ...options.Lister[options.ListDatabasesOptions]) (mongo.ListDatabasesResult, error) + ListDatabaseNames(ctx context.Context, filter interface{}, opts ...options.Lister[options.ListDatabasesOptions]) ([]string, error) + UseSession(ctx context.Context, fn func(context.Context) error) error + UseSessionWithOptions(ctx context.Context, opts *options.SessionOptionsBuilder, fn func(context.Context) error) error + Watch(ctx context.Context, pipeline interface{}, opts ...options.Lister[options.ChangeStreamOptions]) (*mongo.ChangeStream, error) + NumberSessionsInProgress() int +} diff --git a/drivers/mongov2/example_test.go b/drivers/mongov2/example_test.go new file mode 100644 index 0000000..3fe0aa8 --- /dev/null +++ b/drivers/mongov2/example_test.go @@ -0,0 +1,105 @@ +//go:build with_real_db +// +build with_real_db + +package mongov2_test + +import ( + "context" + "fmt" + trmcontext "github.com/avito-tech/go-transaction-manager/trm/v2/context" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" + + "github.com/avito-tech/go-transaction-manager/trm/v2/manager" +) + +// Example demonstrates the implementation of the Repository pattern by trm.Manager. +func Example() { + ctx := context.Background() + + client, err := mongo.Connect(ctx, options.Client(). + ApplyURI("mongodb://127.0.0.1:27017/?directConnection=true")) + checkErr(err) + defer client.Disconnect(ctx) + + collection := client.Database("test").Collection("users") + + r := newRepo(collection, trmmongo.DefaultCtxGetter) + + u := &user{ + ID: 1, + Username: "username", + } + + trManager := manager.Must( + trmmongo.NewDefaultFactory(client), + manager.WithCtxManager(trmcontext.DefaultManager), + ) + + err = trManager.Do(ctx, func(ctx context.Context) error { + if err := r.Save(ctx, u); err != nil { + return err + } + + return trManager.Do(ctx, func(ctx context.Context) error { + u.Username = "new_username" + + return r.Save(ctx, u) + }) + }) + checkErr(err) + + userFromDB, err := r.GetByID(ctx, u.ID) + checkErr(err) + + fmt.Println(userFromDB) + + // Output: &{1 new_username} +} + +type repo struct { + collection *mongo.Collection + getter *trmmongo.CtxGetter +} + +func newRepo(collection *mongo.Collection, c *trmmongo.CtxGetter) *repo { + return &repo{ + collection: collection, + getter: c, + } +} + +type user struct { + ID int64 `bson:"_id"` + Username string `bson:"username"` +} + +func (r *repo) GetByID(ctx context.Context, id int64) (*user, error) { + var result user + + err := r.collection.FindOne(ctx, bson.M{"_id": id}).Decode(&result) + + return &result, err +} + +func (r *repo) Save(ctx context.Context, u *user) error { + if err := r.collection.FindOneAndUpdate( + ctx, + bson.M{"_id": u.ID}, + bson.M{"$set": u}, + options.FindOneAndUpdate(). + SetReturnDocument(options.After). + SetUpsert(true), + ).Err(); err != nil { + return err + } + + return nil +} + +func checkErr(err error, args ...interface{}) { + if err != nil { + panic(fmt.Sprint(append([]interface{}{err}, args...)...)) + } +} diff --git a/drivers/mongov2/factory.go b/drivers/mongov2/factory.go new file mode 100644 index 0000000..03faff5 --- /dev/null +++ b/drivers/mongov2/factory.go @@ -0,0 +1,16 @@ +package mongov2 + +import ( + "context" + + trm "github.com/avito-tech/go-transaction-manager/trm/v2" +) + +// NewDefaultFactory creates default trm.Transaction(mongo.Session). +func NewDefaultFactory(client client) trm.TrFactory { + return func(ctx context.Context, trms trm.Settings) (context.Context, trm.Transaction, error) { + s, _ := trms.(Settings) + + return NewTransaction(ctx, s.SessionOpts(), s.TransactionOpts(), client) + } +} diff --git a/drivers/mongov2/go.mod b/drivers/mongov2/go.mod new file mode 100644 index 0000000..008899a --- /dev/null +++ b/drivers/mongov2/go.mod @@ -0,0 +1,27 @@ +module github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2 + +go 1.21 + +require ( + github.com/avito-tech/go-transaction-manager/trm/v2 v2.0.0-rc10 + github.com/stretchr/testify v1.8.2 + go.mongodb.org/mongo-driver/v2 v2.0.0 + go.uber.org/goleak v1.3.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang/snappy v0.0.4 // indirect + github.com/klauspost/compress v1.16.7 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.2 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect + go.uber.org/atomic v1.7.0 // indirect + go.uber.org/multierr v1.9.0 // indirect + golang.org/x/crypto v0.29.0 // indirect + golang.org/x/sync v0.9.0 // indirect + golang.org/x/text v0.20.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/drivers/mongov2/goroutine_leak_test.go b/drivers/mongov2/goroutine_leak_test.go new file mode 100644 index 0000000..a95a26e --- /dev/null +++ b/drivers/mongov2/goroutine_leak_test.go @@ -0,0 +1,14 @@ +//go:build go1.24 +// +build go1.24 + +package mongov2 + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/drivers/mongov2/settings.go b/drivers/mongov2/settings.go new file mode 100644 index 0000000..e4aaff4 --- /dev/null +++ b/drivers/mongov2/settings.go @@ -0,0 +1,101 @@ +package mongov2 + +import ( + trm "github.com/avito-tech/go-transaction-manager/trm/v2" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +// Opt is a type to configure Settings. +type Opt func(*Settings) error + +// WithSessionOpts sets up options.SessionOptions for the Settings. +func WithSessionOpts(opts *options.SessionOptionsBuilder) Opt { + return func(s *Settings) error { + *s = s.setSessionOpts(opts) + + return nil + } +} + +// WithTransactionOpts sets up options.TransactionOptions for the Settings. +func WithTransactionOpts(opts *options.TransactionOptionsBuilder) Opt { + return func(s *Settings) error { + *s = s.setTransactionOpts(opts) + + return nil + } +} + +// Settings contains settings for mongo.Transaction. +type Settings struct { + trm.Settings + sessionOpts *options.SessionOptionsBuilder + transactionOpts *options.TransactionOptionsBuilder +} + +// NewSettings creates Settings. +func NewSettings(trms trm.Settings, oo ...Opt) (Settings, error) { + s := &Settings{ + Settings: trms, + sessionOpts: nil, + transactionOpts: nil, + } + + for _, o := range oo { + if err := o(s); err != nil { + return Settings{}, err + } + } + + return *s, nil +} + +// MustSettings returns Settings if err is nil and panics otherwise. +func MustSettings(trms trm.Settings, oo ...Opt) Settings { + s, err := NewSettings(trms, oo...) + if err != nil { + panic(err) + } + + return s +} + +// EnrichBy fills nil properties from external Settings. +func (s Settings) EnrichBy(in trm.Settings) trm.Settings { + external, ok := in.(Settings) + if ok { + if s.SessionOpts() == nil { + s = s.setSessionOpts(external.SessionOpts()) + } + + if s.TransactionOpts() == nil { + s = s.setTransactionOpts(external.TransactionOpts()) + } + } + + s.Settings = s.Settings.EnrichBy(in) + + return s +} + +// SessionOpts returns *options.SessionOptions for the trm.Transaction. +func (s Settings) SessionOpts() *options.SessionOptionsBuilder { + return s.sessionOpts +} + +func (s Settings) setSessionOpts(opts *options.SessionOptionsBuilder) Settings { + s.sessionOpts = opts + + return s +} + +// TransactionOpts returns trm.CtxKey for the trm.Transaction. +func (s Settings) TransactionOpts() *options.TransactionOptionsBuilder { + return s.transactionOpts +} + +func (s Settings) setTransactionOpts(opts *options.TransactionOptionsBuilder) Settings { + s.transactionOpts = opts + + return s +} diff --git a/drivers/mongov2/settings_test.go b/drivers/mongov2/settings_test.go new file mode 100644 index 0000000..e08bec9 --- /dev/null +++ b/drivers/mongov2/settings_test.go @@ -0,0 +1,102 @@ +package mongov2 + +import ( + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/mongo/readconcern" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/avito-tech/go-transaction-manager/trm/v2" + "github.com/avito-tech/go-transaction-manager/trm/v2/settings" +) + +func TestSettings_EnrichBy(t *testing.T) { + t.Parallel() + + type args struct { + external trm.Settings + } + + tests := map[string]struct { + settings Settings + args args + want trm.Settings + }{ + "update_default": { + settings: MustSettings(settings.Must()), + args: args{ + external: MustSettings( + settings.Must(settings.WithCancelable(true)), + WithSessionOpts((&options.SessionOptionsBuilder{}). + SetCausalConsistency(true)), + WithTransactionOpts((&options.TransactionOptionsBuilder{}). + SetReadConcern(readconcern.Majority())), + ), + }, + want: MustSettings( + settings.Must(settings.WithCancelable(true)), + WithSessionOpts((&options.SessionOptionsBuilder{}). + SetCausalConsistency(true)), + WithTransactionOpts((&options.TransactionOptionsBuilder{}). + SetReadConcern(readconcern.Majority())), + ), + }, + "without_update": { + settings: MustSettings( + settings.Must(settings.WithCancelable(true)), + WithSessionOpts((&options.SessionOptionsBuilder{}). + SetCausalConsistency(true)), + WithTransactionOpts((&options.TransactionOptionsBuilder{}). + SetReadConcern(readconcern.Majority())), + ), + args: args{ + external: MustSettings( + settings.Must(settings.WithCancelable(false)), + WithSessionOpts((&options.SessionOptionsBuilder{}). + SetCausalConsistency(false)), + WithTransactionOpts((&options.TransactionOptionsBuilder{}). + SetReadConcern(readconcern.Local())), + ), + }, + want: MustSettings( + settings.Must(settings.WithCancelable(true)), + WithSessionOpts((&options.SessionOptionsBuilder{}). + SetCausalConsistency(true)), + WithTransactionOpts((&options.TransactionOptionsBuilder{}). + SetReadConcern(readconcern.Majority())), + ), + }, + "update_only_trm.Settings": { + settings: MustSettings( + settings.Must(), + WithSessionOpts((&options.SessionOptionsBuilder{}). + SetCausalConsistency(true)), + ), + args: args{ + external: settings.Must(settings.WithCancelable(true)), + }, + want: MustSettings( + settings.Must(settings.WithCancelable(true)), + WithSessionOpts((&options.SessionOptionsBuilder{}). + SetCausalConsistency(true)), + ), + }, + } + for name, tt := range tests { + tt := tt + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := tt.settings.EnrichBy(tt.args.external) + + //assert.Equal(t, tt.want, got) + + t.Helper() + assert.Equal(t, tt.want.CtxKey(), got.CtxKey()) + assert.Equal(t, tt.want.Propagation(), got.Propagation()) + assert.Equal(t, tt.want.Cancelable(), got.Cancelable()) + assert.Equal(t, tt.want.TimeoutOrNil(), got.TimeoutOrNil()) + }) + } +} diff --git a/drivers/mongov2/transaction.go b/drivers/mongov2/transaction.go new file mode 100644 index 0000000..0787e6b --- /dev/null +++ b/drivers/mongov2/transaction.go @@ -0,0 +1,86 @@ +// Package mongo is an implementation of trm.Transaction interface by Transaction for mongo.Client. +package mongov2 + +import ( + "context" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" + + "github.com/avito-tech/go-transaction-manager/trm/v2/drivers" +) + +// Transaction is trm.Transaction for mongo.Client. +type Transaction struct { + session *mongo.Session + isClosed *drivers.IsClosed +} + +// NewTransaction creates trm.Transaction for mongo.Client. +func NewTransaction( + ctx context.Context, + sessionOptions *options.SessionOptionsBuilder, + trOpts *options.TransactionOptionsBuilder, + client client, +) (context.Context, *Transaction, error) { + s, err := client.StartSession(sessionOptions) + if err != nil { + return ctx, nil, err + } + + if err = s.StartTransaction(trOpts); err != nil { + defer s.EndSession(ctx) + + return ctx, nil, err + } + + tr := &Transaction{session: s, isClosed: drivers.NewIsClosed()} + + go tr.awaitDone(ctx) + + return mongo.NewSessionContext(ctx, tr.session), tr, nil +} + +func (t *Transaction) awaitDone(ctx context.Context) { + if ctx.Done() == nil { + return + } + + select { + case <-ctx.Done(): + t.isClosed.Close() + case <-t.isClosed.Closed(): + } +} + +// Transaction returns the real transaction mongo.Session. +func (t *Transaction) Transaction() interface{} { + return t.session +} + +// Commit the trm.Transaction. +func (t *Transaction) Commit(ctx context.Context) error { + defer t.isClosed.Close() + + defer t.session.EndSession(ctx) + + return t.session.CommitTransaction(ctx) +} + +// Rollback the trm.Transaction. +func (t *Transaction) Rollback(ctx context.Context) error { + defer t.isClosed.Close() + + defer t.session.EndSession(ctx) + + return t.session.AbortTransaction(ctx) +} + +// IsActive returns true if the transaction started but not committed or rolled back. +func (t *Transaction) IsActive() bool { + return t.isClosed.IsActive() +} + +// Closed returns a channel that's closed when transaction committed or rolled back. +func (t *Transaction) Closed() <-chan struct{} { + return t.isClosed.Closed() +} diff --git a/go.work b/go.work index 47ded15..64d7f44 100644 --- a/go.work +++ b/go.work @@ -4,6 +4,7 @@ use ( ./drivers/goredis8 ./drivers/gorm ./drivers/mongo + ./drivers/mongov2 ./drivers/pgxv4 ./drivers/pgxv5 ./drivers/sql diff --git a/go.work.sum b/go.work.sum index bbb0461..a8537a9 100644 --- a/go.work.sum +++ b/go.work.sum @@ -6,6 +6,7 @@ github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1 h1:q763qf9huN11kDQavWsoZXJNW3xEE4JJyHa5Q25/sd8= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f h1:JOrtw2xFKzlg+cbHpyrpLDmnN1HqhBfnX7WDiW7eG2c= +github.com/creack/pty v1.1.9 h1:uDmaGzcdjhF4i/plgjmEsriH11Y0o7RKapEf/LDaM3w= github.com/go-kit/log v0.1.0 h1:DGJh0Sm43HbOeYDNnVZFl8BvcYVvjD5bqYJvp0REbwQ= github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih4= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= @@ -19,6 +20,7 @@ github.com/jackc/pgconn v1.14.1/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8 github.com/jackc/pgx/v4 v4.18.2 h1:xVpYkNR5pk5bMCZGfClbO962UIqVABcAGt7ha1s/FeU= github.com/jackc/pgx/v4 v4.18.2/go.mod h1:Ey4Oru5tH5sB6tV7hDmfWFahwF15Eb7DNXlRKx2CkVw= github.com/kisielk/gotool v1.0.0 h1:AV2c/EiW3KqPNT9ZKl07ehoAGi4C5/01Cfbblndcapg= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46 h1:veS9QfglfvqAw2e+eeNT/SbGySq8ajECXJ9e4fPoLhY= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/kr/pty v1.1.8 h1:AkaSdXYQOWeaO3neb8EM634ahkXXe3jYbVh/F9lq+GI= github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE= @@ -33,18 +35,24 @@ github.com/stretchr/objx v0.5.1/go.mod h1:/iHQpkQwBD6DLUmQ4pE+s1TXdob1mORJ4/UFdr github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/zenazn/goji v0.9.0 h1:RSQQAbXGArQ0dIDEq+PI6WqN6if+5KHu6x2Cx/GXLTQ= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= go.uber.org/zap v1.13.0 h1:nR6NoDBgAf67s68NhaXbsojM+2gxp3S1hWkHDl27pVU= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2 h1:IRJeR9r1pYWsHKTRe/IInb7lYvbBVIqOgsX/u0mbOWY= golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= +golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +gopkg.in/errgo.v2 v2.1.0 h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec h1:RlWgLqCMMIYYEVcAR5MDsuHlVkaIPDAF+5Dehzg8L5A= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= From babd01b3d2c66c0d892acb4588b10fc728b45f91 Mon Sep 17 00:00:00 2001 From: hutiquan Date: Tue, 1 Jul 2025 14:43:51 +0800 Subject: [PATCH 2/7] feat:mongov2 add unit test --- drivers/mongov2/context_test.go | 59 + drivers/mongov2/go.mod | 4 +- drivers/mongov2/go.sum | 113 ++ drivers/mongov2/goroutine_leak_test.go | 3 +- .../internal/assert/assertion_compare.go | 481 +++++++ .../assert/assertion_compare_can_convert.go | 18 + .../assert/assertion_compare_go1.17_test.go | 184 +++ .../assert/assertion_compare_legacy.go | 18 + .../internal/assert/assertion_compare_test.go | 455 ++++++ .../internal/assert/assertion_format.go | 325 +++++ .../internal/assert/assertion_mongo.go | 126 ++ .../internal/assert/assertion_mongo_test.go | 125 ++ drivers/mongov2/internal/assert/assertions.go | 1075 ++++++++++++++ .../internal/assert/assertions_test.go | 1231 ++++++++++++++++ drivers/mongov2/internal/assert/difflib.go | 766 ++++++++++ .../mongov2/internal/assert/difflib_test.go | 326 +++++ drivers/mongov2/internal/csfle/csfle.go | 40 + .../internal/driverutil/description.go | 493 +++++++ drivers/mongov2/internal/driverutil/hello.go | 128 ++ .../mongov2/internal/driverutil/operation.go | 31 + .../mongov2/internal/failpoint/failpoint.go | 63 + .../mongov2/internal/integtest/integtest.go | 295 ++++ .../mongov2/internal/mongoutil/mongoutil.go | 85 ++ .../internal/mongoutil/mongoutil_test.go | 34 + .../mongov2/internal/mtest/csfle_enabled.go | 16 + .../internal/mtest/csfle_not_enabled.go | 16 + .../internal/mtest/deployment_helpers.go | 120 ++ drivers/mongov2/internal/mtest/doc.go | 9 + .../mongov2/internal/mtest/global_state.go | 96 ++ drivers/mongov2/internal/mtest/mongotest.go | 874 +++++++++++ drivers/mongov2/internal/mtest/options.go | 283 ++++ .../mongov2/internal/mtest/proxy_dialer.go | 186 +++ .../internal/mtest/received_message.go | 124 ++ .../mongov2/internal/mtest/sent_message.go | 195 +++ drivers/mongov2/internal/mtest/setup.go | 376 +++++ .../mongov2/internal/mtest/setup_options.go | 25 + .../internal/mtest/wiremessage_helpers.go | 67 + drivers/mongov2/internal/require/require.go | 819 +++++++++++ .../serverselector/server_selector.go | 359 +++++ .../serverselector/server_selector_test.go | 1278 +++++++++++++++++ drivers/mongov2/internal/spectest/spectest.go | 35 + drivers/mongov2/settings_test.go | 5 +- drivers/mongov2/transaction_test.go | 231 +++ 43 files changed, 11586 insertions(+), 6 deletions(-) create mode 100644 drivers/mongov2/context_test.go create mode 100644 drivers/mongov2/go.sum create mode 100644 drivers/mongov2/internal/assert/assertion_compare.go create mode 100644 drivers/mongov2/internal/assert/assertion_compare_can_convert.go create mode 100644 drivers/mongov2/internal/assert/assertion_compare_go1.17_test.go create mode 100644 drivers/mongov2/internal/assert/assertion_compare_legacy.go create mode 100644 drivers/mongov2/internal/assert/assertion_compare_test.go create mode 100644 drivers/mongov2/internal/assert/assertion_format.go create mode 100644 drivers/mongov2/internal/assert/assertion_mongo.go create mode 100644 drivers/mongov2/internal/assert/assertion_mongo_test.go create mode 100644 drivers/mongov2/internal/assert/assertions.go create mode 100644 drivers/mongov2/internal/assert/assertions_test.go create mode 100644 drivers/mongov2/internal/assert/difflib.go create mode 100644 drivers/mongov2/internal/assert/difflib_test.go create mode 100644 drivers/mongov2/internal/csfle/csfle.go create mode 100644 drivers/mongov2/internal/driverutil/description.go create mode 100644 drivers/mongov2/internal/driverutil/hello.go create mode 100644 drivers/mongov2/internal/driverutil/operation.go create mode 100644 drivers/mongov2/internal/failpoint/failpoint.go create mode 100644 drivers/mongov2/internal/integtest/integtest.go create mode 100644 drivers/mongov2/internal/mongoutil/mongoutil.go create mode 100644 drivers/mongov2/internal/mongoutil/mongoutil_test.go create mode 100644 drivers/mongov2/internal/mtest/csfle_enabled.go create mode 100644 drivers/mongov2/internal/mtest/csfle_not_enabled.go create mode 100644 drivers/mongov2/internal/mtest/deployment_helpers.go create mode 100644 drivers/mongov2/internal/mtest/doc.go create mode 100644 drivers/mongov2/internal/mtest/global_state.go create mode 100644 drivers/mongov2/internal/mtest/mongotest.go create mode 100644 drivers/mongov2/internal/mtest/options.go create mode 100644 drivers/mongov2/internal/mtest/proxy_dialer.go create mode 100644 drivers/mongov2/internal/mtest/received_message.go create mode 100644 drivers/mongov2/internal/mtest/sent_message.go create mode 100644 drivers/mongov2/internal/mtest/setup.go create mode 100644 drivers/mongov2/internal/mtest/setup_options.go create mode 100644 drivers/mongov2/internal/mtest/wiremessage_helpers.go create mode 100644 drivers/mongov2/internal/require/require.go create mode 100644 drivers/mongov2/internal/serverselector/server_selector.go create mode 100644 drivers/mongov2/internal/serverselector/server_selector_test.go create mode 100644 drivers/mongov2/internal/spectest/spectest.go create mode 100644 drivers/mongov2/transaction_test.go diff --git a/drivers/mongov2/context_test.go b/drivers/mongov2/context_test.go new file mode 100644 index 0000000..822a589 --- /dev/null +++ b/drivers/mongov2/context_test.go @@ -0,0 +1,59 @@ +//go:build go1.21 + +package mongov2 + +import ( + "context" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/mtest" + "testing" + + "github.com/avito-tech/go-transaction-manager/trm/v2/manager" + "github.com/avito-tech/go-transaction-manager/trm/v2/settings" + "github.com/stretchr/testify/require" +) + +func TestContext(t *testing.T) { + t.Parallel() + + ctx := context.Background() + mt := mtest.New( + t, + mtest.NewOptions().ClientType(mtest.Mock), + ) + + mt.Run("all", func(mt *mtest.T) { + mt.Parallel() + + m := manager.Must( + NewDefaultFactory(mt.Client), + ) + + err := m.Do(ctx, func(ctx context.Context) error { + tr := DefaultCtxGetter.TrOrDB(ctx, settings.DefaultCtxKey, nil) + require.NotNil(t, tr) + + tr = DefaultCtxGetter.DefaultTrOrDB(ctx, nil) + require.NotNil(t, tr) + + tr = DefaultCtxGetter.TrOrDB(ctx, "invalid ley", nil) + require.Nil(t, tr) + + err := m.Do(ctx, func(ctx context.Context) error { + tr = DefaultCtxGetter.DefaultTrOrDB(ctx, nil) + require.NotNil(t, tr) + + tr = DefaultCtxGetter.TrOrDB(ctx, settings.DefaultCtxKey, nil) + require.NotNil(t, tr) + + tr = DefaultCtxGetter.TrOrDB(ctx, "invalid ley", nil) + require.Nil(t, tr) + + return nil + }) + + return err + }) + + require.NoError(t, err) + }) +} diff --git a/drivers/mongov2/go.mod b/drivers/mongov2/go.mod index 008899a..e92c319 100644 --- a/drivers/mongov2/go.mod +++ b/drivers/mongov2/go.mod @@ -4,13 +4,15 @@ go 1.21 require ( github.com/avito-tech/go-transaction-manager/trm/v2 v2.0.0-rc10 + github.com/davecgh/go-spew v1.1.1 + github.com/google/go-cmp v0.6.0 github.com/stretchr/testify v1.8.2 go.mongodb.org/mongo-driver/v2 v2.0.0 go.uber.org/goleak v1.3.0 ) require ( - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang/mock v1.6.0 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/klauspost/compress v1.16.7 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/drivers/mongov2/go.sum b/drivers/mongov2/go.sum new file mode 100644 index 0000000..e27b58a --- /dev/null +++ b/drivers/mongov2/go.sum @@ -0,0 +1,113 @@ +github.com/DATA-DOG/go-sqlmock v1.5.1/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/avito-tech/go-transaction-manager/trm/v2 v2.0.0-rc10 h1:SqfNHnRw9CeroyLp4aVJVnmNaSemjbGy0nhiSGerGW4= +github.com/avito-tech/go-transaction-manager/trm/v2 v2.0.0-rc10/go.mod h1:qUNVecb/ahohzAvtGvjfWTeCOejgRRiO/2C4cDvtLjI= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= +github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= +github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.14/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.mongodb.org/mongo-driver/v2 v2.0.0 h1:Jfd7XpdZa9yk3eY774bO7SWVb30noLSirL9nKTpavhI= +go.mongodb.org/mongo-driver/v2 v2.0.0/go.mod h1:nSjmNq4JUstE8IRZKTktLgMHM4F1fccL6HGX1yh+8RA= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= +golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= +golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/drivers/mongov2/goroutine_leak_test.go b/drivers/mongov2/goroutine_leak_test.go index a95a26e..1069540 100644 --- a/drivers/mongov2/goroutine_leak_test.go +++ b/drivers/mongov2/goroutine_leak_test.go @@ -1,5 +1,4 @@ -//go:build go1.24 -// +build go1.24 +//go:build go1.21 package mongov2 diff --git a/drivers/mongov2/internal/assert/assertion_compare.go b/drivers/mongov2/internal/assert/assertion_compare.go new file mode 100644 index 0000000..0a8307d --- /dev/null +++ b/drivers/mongov2/internal/assert/assertion_compare.go @@ -0,0 +1,481 @@ +// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare.go + +// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved. +// Use of this source code is governed by an MIT-style license that can be found in +// the THIRD-PARTY-NOTICES file. + +package assert + +import ( + "bytes" + "fmt" + "reflect" + "time" +) + +type CompareType int + +const ( + compareLess CompareType = iota - 1 + compareEqual + compareGreater +) + +var ( + intType = reflect.TypeOf(int(1)) + int8Type = reflect.TypeOf(int8(1)) + int16Type = reflect.TypeOf(int16(1)) + int32Type = reflect.TypeOf(int32(1)) + int64Type = reflect.TypeOf(int64(1)) + + uintType = reflect.TypeOf(uint(1)) + uint8Type = reflect.TypeOf(uint8(1)) + uint16Type = reflect.TypeOf(uint16(1)) + uint32Type = reflect.TypeOf(uint32(1)) + uint64Type = reflect.TypeOf(uint64(1)) + + float32Type = reflect.TypeOf(float32(1)) + float64Type = reflect.TypeOf(float64(1)) + + stringType = reflect.TypeOf("") + + timeType = reflect.TypeOf(time.Time{}) + bytesType = reflect.TypeOf([]byte{}) +) + +func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { + obj1Value := reflect.ValueOf(obj1) + obj2Value := reflect.ValueOf(obj2) + + // throughout this switch we try and avoid calling .Convert() if possible, + // as this has a pretty big performance impact + switch kind { + case reflect.Int: + { + intobj1, ok := obj1.(int) + if !ok { + intobj1 = obj1Value.Convert(intType).Interface().(int) + } + intobj2, ok := obj2.(int) + if !ok { + intobj2 = obj2Value.Convert(intType).Interface().(int) + } + if intobj1 > intobj2 { + return compareGreater, true + } + if intobj1 == intobj2 { + return compareEqual, true + } + if intobj1 < intobj2 { + return compareLess, true + } + } + case reflect.Int8: + { + int8obj1, ok := obj1.(int8) + if !ok { + int8obj1 = obj1Value.Convert(int8Type).Interface().(int8) + } + int8obj2, ok := obj2.(int8) + if !ok { + int8obj2 = obj2Value.Convert(int8Type).Interface().(int8) + } + if int8obj1 > int8obj2 { + return compareGreater, true + } + if int8obj1 == int8obj2 { + return compareEqual, true + } + if int8obj1 < int8obj2 { + return compareLess, true + } + } + case reflect.Int16: + { + int16obj1, ok := obj1.(int16) + if !ok { + int16obj1 = obj1Value.Convert(int16Type).Interface().(int16) + } + int16obj2, ok := obj2.(int16) + if !ok { + int16obj2 = obj2Value.Convert(int16Type).Interface().(int16) + } + if int16obj1 > int16obj2 { + return compareGreater, true + } + if int16obj1 == int16obj2 { + return compareEqual, true + } + if int16obj1 < int16obj2 { + return compareLess, true + } + } + case reflect.Int32: + { + int32obj1, ok := obj1.(int32) + if !ok { + int32obj1 = obj1Value.Convert(int32Type).Interface().(int32) + } + int32obj2, ok := obj2.(int32) + if !ok { + int32obj2 = obj2Value.Convert(int32Type).Interface().(int32) + } + if int32obj1 > int32obj2 { + return compareGreater, true + } + if int32obj1 == int32obj2 { + return compareEqual, true + } + if int32obj1 < int32obj2 { + return compareLess, true + } + } + case reflect.Int64: + { + int64obj1, ok := obj1.(int64) + if !ok { + int64obj1 = obj1Value.Convert(int64Type).Interface().(int64) + } + int64obj2, ok := obj2.(int64) + if !ok { + int64obj2 = obj2Value.Convert(int64Type).Interface().(int64) + } + if int64obj1 > int64obj2 { + return compareGreater, true + } + if int64obj1 == int64obj2 { + return compareEqual, true + } + if int64obj1 < int64obj2 { + return compareLess, true + } + } + case reflect.Uint: + { + uintobj1, ok := obj1.(uint) + if !ok { + uintobj1 = obj1Value.Convert(uintType).Interface().(uint) + } + uintobj2, ok := obj2.(uint) + if !ok { + uintobj2 = obj2Value.Convert(uintType).Interface().(uint) + } + if uintobj1 > uintobj2 { + return compareGreater, true + } + if uintobj1 == uintobj2 { + return compareEqual, true + } + if uintobj1 < uintobj2 { + return compareLess, true + } + } + case reflect.Uint8: + { + uint8obj1, ok := obj1.(uint8) + if !ok { + uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8) + } + uint8obj2, ok := obj2.(uint8) + if !ok { + uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8) + } + if uint8obj1 > uint8obj2 { + return compareGreater, true + } + if uint8obj1 == uint8obj2 { + return compareEqual, true + } + if uint8obj1 < uint8obj2 { + return compareLess, true + } + } + case reflect.Uint16: + { + uint16obj1, ok := obj1.(uint16) + if !ok { + uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16) + } + uint16obj2, ok := obj2.(uint16) + if !ok { + uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16) + } + if uint16obj1 > uint16obj2 { + return compareGreater, true + } + if uint16obj1 == uint16obj2 { + return compareEqual, true + } + if uint16obj1 < uint16obj2 { + return compareLess, true + } + } + case reflect.Uint32: + { + uint32obj1, ok := obj1.(uint32) + if !ok { + uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32) + } + uint32obj2, ok := obj2.(uint32) + if !ok { + uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32) + } + if uint32obj1 > uint32obj2 { + return compareGreater, true + } + if uint32obj1 == uint32obj2 { + return compareEqual, true + } + if uint32obj1 < uint32obj2 { + return compareLess, true + } + } + case reflect.Uint64: + { + uint64obj1, ok := obj1.(uint64) + if !ok { + uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64) + } + uint64obj2, ok := obj2.(uint64) + if !ok { + uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64) + } + if uint64obj1 > uint64obj2 { + return compareGreater, true + } + if uint64obj1 == uint64obj2 { + return compareEqual, true + } + if uint64obj1 < uint64obj2 { + return compareLess, true + } + } + case reflect.Float32: + { + float32obj1, ok := obj1.(float32) + if !ok { + float32obj1 = obj1Value.Convert(float32Type).Interface().(float32) + } + float32obj2, ok := obj2.(float32) + if !ok { + float32obj2 = obj2Value.Convert(float32Type).Interface().(float32) + } + if float32obj1 > float32obj2 { + return compareGreater, true + } + if float32obj1 == float32obj2 { + return compareEqual, true + } + if float32obj1 < float32obj2 { + return compareLess, true + } + } + case reflect.Float64: + { + float64obj1, ok := obj1.(float64) + if !ok { + float64obj1 = obj1Value.Convert(float64Type).Interface().(float64) + } + float64obj2, ok := obj2.(float64) + if !ok { + float64obj2 = obj2Value.Convert(float64Type).Interface().(float64) + } + if float64obj1 > float64obj2 { + return compareGreater, true + } + if float64obj1 == float64obj2 { + return compareEqual, true + } + if float64obj1 < float64obj2 { + return compareLess, true + } + } + case reflect.String: + { + stringobj1, ok := obj1.(string) + if !ok { + stringobj1 = obj1Value.Convert(stringType).Interface().(string) + } + stringobj2, ok := obj2.(string) + if !ok { + stringobj2 = obj2Value.Convert(stringType).Interface().(string) + } + if stringobj1 > stringobj2 { + return compareGreater, true + } + if stringobj1 == stringobj2 { + return compareEqual, true + } + if stringobj1 < stringobj2 { + return compareLess, true + } + } + // Check for known struct types we can check for compare results. + case reflect.Struct: + { + // All structs enter here. We're not interested in most types. + if !canConvert(obj1Value, timeType) { + break + } + + // time.Time can compared! + timeObj1, ok := obj1.(time.Time) + if !ok { + timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time) + } + + timeObj2, ok := obj2.(time.Time) + if !ok { + timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time) + } + + return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64) + } + case reflect.Slice: + { + // We only care about the []byte type. + if !canConvert(obj1Value, bytesType) { + break + } + + // []byte can be compared! + bytesObj1, ok := obj1.([]byte) + if !ok { + bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte) + + } + bytesObj2, ok := obj2.([]byte) + if !ok { + bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte) + } + + return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true + } + } + + return compareEqual, false +} + +// Greater asserts that the first element is greater than the second +// +// assert.Greater(t, 2, 1) +// assert.Greater(t, float64(2), float64(1)) +// assert.Greater(t, "b", "a") +func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...) +} + +// GreaterOrEqual asserts that the first element is greater than or equal to the second +// +// assert.GreaterOrEqual(t, 2, 1) +// assert.GreaterOrEqual(t, 2, 2) +// assert.GreaterOrEqual(t, "b", "a") +// assert.GreaterOrEqual(t, "b", "b") +func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...) +} + +// Less asserts that the first element is less than the second +// +// assert.Less(t, 1, 2) +// assert.Less(t, float64(1), float64(2)) +// assert.Less(t, "a", "b") +func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...) +} + +// LessOrEqual asserts that the first element is less than or equal to the second +// +// assert.LessOrEqual(t, 1, 2) +// assert.LessOrEqual(t, 2, 2) +// assert.LessOrEqual(t, "a", "b") +// assert.LessOrEqual(t, "b", "b") +func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...) +} + +// Positive asserts that the specified element is positive +// +// assert.Positive(t, 1) +// assert.Positive(t, 1.23) +func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + zero := reflect.Zero(reflect.TypeOf(e)) + return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...) +} + +// Negative asserts that the specified element is negative +// +// assert.Negative(t, -1) +// assert.Negative(t, -1.23) +func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + zero := reflect.Zero(reflect.TypeOf(e)) + return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...) +} + +func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + e1Kind := reflect.ValueOf(e1).Kind() + e2Kind := reflect.ValueOf(e2).Kind() + if e1Kind != e2Kind { + return Fail(t, "Elements should be the same type", msgAndArgs...) + } + + compareResult, isComparable := compare(e1, e2, e1Kind) + if !isComparable { + return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...) + } + + if !containsValue(allowedComparesResults, compareResult) { + return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...) + } + + return true +} + +func containsValue(values []CompareType, value CompareType) bool { + for _, v := range values { + if v == value { + return true + } + } + + return false +} + +// CompareErrors asserts two errors +func CompareErrors(err1, err2 error) bool { + if err1 == nil && err2 == nil { + return true + } + + if err1 == nil || err2 == nil { + return false + } + + if err1.Error() != err2.Error() { + return false + } + + return true +} diff --git a/drivers/mongov2/internal/assert/assertion_compare_can_convert.go b/drivers/mongov2/internal/assert/assertion_compare_can_convert.go new file mode 100644 index 0000000..ff7f9e8 --- /dev/null +++ b/drivers/mongov2/internal/assert/assertion_compare_can_convert.go @@ -0,0 +1,18 @@ +// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_can_convert.go + +// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved. +// Use of this source code is governed by an MIT-style license that can be found in +// the THIRD-PARTY-NOTICES file. + +//go:build go1.17 +// +build go1.17 + +package assert + +import "reflect" + +// Wrapper around reflect.Value.CanConvert, for compatibility +// reasons. +func canConvert(value reflect.Value, to reflect.Type) bool { + return value.CanConvert(to) +} diff --git a/drivers/mongov2/internal/assert/assertion_compare_go1.17_test.go b/drivers/mongov2/internal/assert/assertion_compare_go1.17_test.go new file mode 100644 index 0000000..49ce459 --- /dev/null +++ b/drivers/mongov2/internal/assert/assertion_compare_go1.17_test.go @@ -0,0 +1,184 @@ +// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_go1.17_test.go + +// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved. +// Use of this source code is governed by an MIT-style license that can be found in +// the THIRD-PARTY-NOTICES file. + +//go:build go1.17 +// +build go1.17 + +package assert + +import ( + "bytes" + "reflect" + "testing" + "time" +) + +func TestCompare17(t *testing.T) { + type customTime time.Time + type customBytes []byte + for _, currCase := range []struct { + less interface{} + greater interface{} + cType string + }{ + {less: time.Now(), greater: time.Now().Add(time.Hour), cType: "time.Time"}, + {less: customTime(time.Now()), greater: customTime(time.Now().Add(time.Hour)), cType: "time.Time"}, + {less: []byte{1, 1}, greater: []byte{1, 2}, cType: "[]byte"}, + {less: customBytes([]byte{1, 1}), greater: customBytes([]byte{1, 2}), cType: "[]byte"}, + } { + resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind()) + if !isComparable { + t.Error("object should be comparable for type " + currCase.cType) + } + + if resLess != compareLess { + t.Errorf("object less (%v) should be less than greater (%v) for type "+currCase.cType, + currCase.less, currCase.greater) + } + + resGreater, isComparable := compare(currCase.greater, currCase.less, reflect.ValueOf(currCase.less).Kind()) + if !isComparable { + t.Error("object are comparable for type " + currCase.cType) + } + + if resGreater != compareGreater { + t.Errorf("object greater should be greater than less for type " + currCase.cType) + } + + resEqual, isComparable := compare(currCase.less, currCase.less, reflect.ValueOf(currCase.less).Kind()) + if !isComparable { + t.Error("object are comparable for type " + currCase.cType) + } + + if resEqual != 0 { + t.Errorf("objects should be equal for type " + currCase.cType) + } + } +} + +func TestGreater17(t *testing.T) { + mockT := new(testing.T) + + if !Greater(mockT, 2, 1) { + t.Error("Greater should return true") + } + + if Greater(mockT, 1, 1) { + t.Error("Greater should return false") + } + + if Greater(mockT, 1, 2) { + t.Error("Greater should return false") + } + + // Check error report + for _, currCase := range []struct { + less interface{} + greater interface{} + msg string + }{ + {less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than "[1 2]"`}, + {less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than "0001-01-01 01:00:00 +0000 UTC"`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, Greater(out, currCase.less, currCase.greater)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "go.mongodb.org/mongo-driver/v2/internal/assert.Greater") + } +} + +func TestGreaterOrEqual17(t *testing.T) { + mockT := new(testing.T) + + if !GreaterOrEqual(mockT, 2, 1) { + t.Error("GreaterOrEqual should return true") + } + + if !GreaterOrEqual(mockT, 1, 1) { + t.Error("GreaterOrEqual should return true") + } + + if GreaterOrEqual(mockT, 1, 2) { + t.Error("GreaterOrEqual should return false") + } + + // Check error report + for _, currCase := range []struct { + less interface{} + greater interface{} + msg string + }{ + {less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than or equal to "[1 2]"`}, + {less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than or equal to "0001-01-01 01:00:00 +0000 UTC"`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, GreaterOrEqual(out, currCase.less, currCase.greater)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "go.mongodb.org/mongo-driver/v2/internal/assert.GreaterOrEqual") + } +} + +func TestLess17(t *testing.T) { + mockT := new(testing.T) + + if !Less(mockT, 1, 2) { + t.Error("Less should return true") + } + + if Less(mockT, 1, 1) { + t.Error("Less should return false") + } + + if Less(mockT, 2, 1) { + t.Error("Less should return false") + } + + // Check error report + for _, currCase := range []struct { + less interface{} + greater interface{} + msg string + }{ + {less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than "[1 1]"`}, + {less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than "0001-01-01 00:00:00 +0000 UTC"`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, Less(out, currCase.greater, currCase.less)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "go.mongodb.org/mongo-driver/v2/internal/assert.Less") + } +} + +func TestLessOrEqual17(t *testing.T) { + mockT := new(testing.T) + + if !LessOrEqual(mockT, 1, 2) { + t.Error("LessOrEqual should return true") + } + + if !LessOrEqual(mockT, 1, 1) { + t.Error("LessOrEqual should return true") + } + + if LessOrEqual(mockT, 2, 1) { + t.Error("LessOrEqual should return false") + } + + // Check error report + for _, currCase := range []struct { + less interface{} + greater interface{} + msg string + }{ + {less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than or equal to "[1 1]"`}, + {less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than or equal to "0001-01-01 00:00:00 +0000 UTC"`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, LessOrEqual(out, currCase.greater, currCase.less)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "go.mongodb.org/mongo-driver/v2/internal/assert.LessOrEqual") + } +} diff --git a/drivers/mongov2/internal/assert/assertion_compare_legacy.go b/drivers/mongov2/internal/assert/assertion_compare_legacy.go new file mode 100644 index 0000000..c23c7d1 --- /dev/null +++ b/drivers/mongov2/internal/assert/assertion_compare_legacy.go @@ -0,0 +1,18 @@ +// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_legacy.go + +// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved. +// Use of this source code is governed by an MIT-style license that can be found in +// the THIRD-PARTY-NOTICES file. + +//go:build !go1.17 +// +build !go1.17 + +package assert + +import "reflect" + +// Older versions of Go does not have the reflect.Value.CanConvert +// method. +func canConvert(value reflect.Value, to reflect.Type) bool { + return false +} diff --git a/drivers/mongov2/internal/assert/assertion_compare_test.go b/drivers/mongov2/internal/assert/assertion_compare_test.go new file mode 100644 index 0000000..36acdd8 --- /dev/null +++ b/drivers/mongov2/internal/assert/assertion_compare_test.go @@ -0,0 +1,455 @@ +// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_test.go + +// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved. +// Use of this source code is governed by an MIT-style license that can be found in +// the THIRD-PARTY-NOTICES file. + +package assert + +import ( + "bytes" + "fmt" + "reflect" + "runtime" + "testing" +) + +func TestCompare(t *testing.T) { + type customInt int + type customInt8 int8 + type customInt16 int16 + type customInt32 int32 + type customInt64 int64 + type customUInt uint + type customUInt8 uint8 + type customUInt16 uint16 + type customUInt32 uint32 + type customUInt64 uint64 + type customFloat32 float32 + type customFloat64 float64 + type customString string + for _, currCase := range []struct { + less interface{} + greater interface{} + cType string + }{ + {less: customString("a"), greater: customString("b"), cType: "string"}, + {less: "a", greater: "b", cType: "string"}, + {less: customInt(1), greater: customInt(2), cType: "int"}, + {less: int(1), greater: int(2), cType: "int"}, + {less: customInt8(1), greater: customInt8(2), cType: "int8"}, + {less: int8(1), greater: int8(2), cType: "int8"}, + {less: customInt16(1), greater: customInt16(2), cType: "int16"}, + {less: int16(1), greater: int16(2), cType: "int16"}, + {less: customInt32(1), greater: customInt32(2), cType: "int32"}, + {less: int32(1), greater: int32(2), cType: "int32"}, + {less: customInt64(1), greater: customInt64(2), cType: "int64"}, + {less: int64(1), greater: int64(2), cType: "int64"}, + {less: customUInt(1), greater: customUInt(2), cType: "uint"}, + {less: uint8(1), greater: uint8(2), cType: "uint8"}, + {less: customUInt8(1), greater: customUInt8(2), cType: "uint8"}, + {less: uint16(1), greater: uint16(2), cType: "uint16"}, + {less: customUInt16(1), greater: customUInt16(2), cType: "uint16"}, + {less: uint32(1), greater: uint32(2), cType: "uint32"}, + {less: customUInt32(1), greater: customUInt32(2), cType: "uint32"}, + {less: uint64(1), greater: uint64(2), cType: "uint64"}, + {less: customUInt64(1), greater: customUInt64(2), cType: "uint64"}, + {less: float32(1.23), greater: float32(2.34), cType: "float32"}, + {less: customFloat32(1.23), greater: customFloat32(2.23), cType: "float32"}, + {less: float64(1.23), greater: float64(2.34), cType: "float64"}, + {less: customFloat64(1.23), greater: customFloat64(2.34), cType: "float64"}, + } { + resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind()) + if !isComparable { + t.Error("object should be comparable for type " + currCase.cType) + } + + if resLess != compareLess { + t.Errorf("object less (%v) should be less than greater (%v) for type "+currCase.cType, + currCase.less, currCase.greater) + } + + resGreater, isComparable := compare(currCase.greater, currCase.less, reflect.ValueOf(currCase.less).Kind()) + if !isComparable { + t.Error("object are comparable for type " + currCase.cType) + } + + if resGreater != compareGreater { + t.Errorf("object greater should be greater than less for type " + currCase.cType) + } + + resEqual, isComparable := compare(currCase.less, currCase.less, reflect.ValueOf(currCase.less).Kind()) + if !isComparable { + t.Error("object are comparable for type " + currCase.cType) + } + + if resEqual != 0 { + t.Errorf("objects should be equal for type " + currCase.cType) + } + } +} + +type outputT struct { + buf *bytes.Buffer + helpers map[string]struct{} +} + +// Implements TestingT +func (t *outputT) Errorf(format string, args ...interface{}) { + s := fmt.Sprintf(format, args...) + t.buf.WriteString(s) +} + +func (t *outputT) Helper() { + if t.helpers == nil { + t.helpers = make(map[string]struct{}) + } + t.helpers[callerName(1)] = struct{}{} +} + +// callerName gives the function name (qualified with a package path) +// for the caller after skip frames (where 0 means the current function). +func callerName(skip int) string { + // Make room for the skip PC. + var pc [1]uintptr + n := runtime.Callers(skip+2, pc[:]) // skip + runtime.Callers + callerName + if n == 0 { + panic("testing: zero callers found") + } + frames := runtime.CallersFrames(pc[:n]) + frame, _ := frames.Next() + return frame.Function +} + +func TestGreater(t *testing.T) { + mockT := new(testing.T) + + if !Greater(mockT, 2, 1) { + t.Error("Greater should return true") + } + + if Greater(mockT, 1, 1) { + t.Error("Greater should return false") + } + + if Greater(mockT, 1, 2) { + t.Error("Greater should return false") + } + + // Check error report + for _, currCase := range []struct { + less interface{} + greater interface{} + msg string + }{ + {less: "a", greater: "b", msg: `"a" is not greater than "b"`}, + {less: int(1), greater: int(2), msg: `"1" is not greater than "2"`}, + {less: int8(1), greater: int8(2), msg: `"1" is not greater than "2"`}, + {less: int16(1), greater: int16(2), msg: `"1" is not greater than "2"`}, + {less: int32(1), greater: int32(2), msg: `"1" is not greater than "2"`}, + {less: int64(1), greater: int64(2), msg: `"1" is not greater than "2"`}, + {less: uint8(1), greater: uint8(2), msg: `"1" is not greater than "2"`}, + {less: uint16(1), greater: uint16(2), msg: `"1" is not greater than "2"`}, + {less: uint32(1), greater: uint32(2), msg: `"1" is not greater than "2"`}, + {less: uint64(1), greater: uint64(2), msg: `"1" is not greater than "2"`}, + {less: float32(1.23), greater: float32(2.34), msg: `"1.23" is not greater than "2.34"`}, + {less: float64(1.23), greater: float64(2.34), msg: `"1.23" is not greater than "2.34"`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, Greater(out, currCase.less, currCase.greater)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "go.mongodb.org/mongo-driver/v2/internal/assert.Greater") + } +} + +func TestGreaterOrEqual(t *testing.T) { + mockT := new(testing.T) + + if !GreaterOrEqual(mockT, 2, 1) { + t.Error("GreaterOrEqual should return true") + } + + if !GreaterOrEqual(mockT, 1, 1) { + t.Error("GreaterOrEqual should return true") + } + + if GreaterOrEqual(mockT, 1, 2) { + t.Error("GreaterOrEqual should return false") + } + + // Check error report + for _, currCase := range []struct { + less interface{} + greater interface{} + msg string + }{ + {less: "a", greater: "b", msg: `"a" is not greater than or equal to "b"`}, + {less: int(1), greater: int(2), msg: `"1" is not greater than or equal to "2"`}, + {less: int8(1), greater: int8(2), msg: `"1" is not greater than or equal to "2"`}, + {less: int16(1), greater: int16(2), msg: `"1" is not greater than or equal to "2"`}, + {less: int32(1), greater: int32(2), msg: `"1" is not greater than or equal to "2"`}, + {less: int64(1), greater: int64(2), msg: `"1" is not greater than or equal to "2"`}, + {less: uint8(1), greater: uint8(2), msg: `"1" is not greater than or equal to "2"`}, + {less: uint16(1), greater: uint16(2), msg: `"1" is not greater than or equal to "2"`}, + {less: uint32(1), greater: uint32(2), msg: `"1" is not greater than or equal to "2"`}, + {less: uint64(1), greater: uint64(2), msg: `"1" is not greater than or equal to "2"`}, + {less: float32(1.23), greater: float32(2.34), msg: `"1.23" is not greater than or equal to "2.34"`}, + {less: float64(1.23), greater: float64(2.34), msg: `"1.23" is not greater than or equal to "2.34"`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, GreaterOrEqual(out, currCase.less, currCase.greater)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "go.mongodb.org/mongo-driver/v2/internal/assert.GreaterOrEqual") + } +} + +func TestLess(t *testing.T) { + mockT := new(testing.T) + + if !Less(mockT, 1, 2) { + t.Error("Less should return true") + } + + if Less(mockT, 1, 1) { + t.Error("Less should return false") + } + + if Less(mockT, 2, 1) { + t.Error("Less should return false") + } + + // Check error report + for _, currCase := range []struct { + less interface{} + greater interface{} + msg string + }{ + {less: "a", greater: "b", msg: `"b" is not less than "a"`}, + {less: int(1), greater: int(2), msg: `"2" is not less than "1"`}, + {less: int8(1), greater: int8(2), msg: `"2" is not less than "1"`}, + {less: int16(1), greater: int16(2), msg: `"2" is not less than "1"`}, + {less: int32(1), greater: int32(2), msg: `"2" is not less than "1"`}, + {less: int64(1), greater: int64(2), msg: `"2" is not less than "1"`}, + {less: uint8(1), greater: uint8(2), msg: `"2" is not less than "1"`}, + {less: uint16(1), greater: uint16(2), msg: `"2" is not less than "1"`}, + {less: uint32(1), greater: uint32(2), msg: `"2" is not less than "1"`}, + {less: uint64(1), greater: uint64(2), msg: `"2" is not less than "1"`}, + {less: float32(1.23), greater: float32(2.34), msg: `"2.34" is not less than "1.23"`}, + {less: float64(1.23), greater: float64(2.34), msg: `"2.34" is not less than "1.23"`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, Less(out, currCase.greater, currCase.less)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "go.mongodb.org/mongo-driver/v2/internal/assert.Less") + } +} + +func TestLessOrEqual(t *testing.T) { + mockT := new(testing.T) + + if !LessOrEqual(mockT, 1, 2) { + t.Error("LessOrEqual should return true") + } + + if !LessOrEqual(mockT, 1, 1) { + t.Error("LessOrEqual should return true") + } + + if LessOrEqual(mockT, 2, 1) { + t.Error("LessOrEqual should return false") + } + + // Check error report + for _, currCase := range []struct { + less interface{} + greater interface{} + msg string + }{ + {less: "a", greater: "b", msg: `"b" is not less than or equal to "a"`}, + {less: int(1), greater: int(2), msg: `"2" is not less than or equal to "1"`}, + {less: int8(1), greater: int8(2), msg: `"2" is not less than or equal to "1"`}, + {less: int16(1), greater: int16(2), msg: `"2" is not less than or equal to "1"`}, + {less: int32(1), greater: int32(2), msg: `"2" is not less than or equal to "1"`}, + {less: int64(1), greater: int64(2), msg: `"2" is not less than or equal to "1"`}, + {less: uint8(1), greater: uint8(2), msg: `"2" is not less than or equal to "1"`}, + {less: uint16(1), greater: uint16(2), msg: `"2" is not less than or equal to "1"`}, + {less: uint32(1), greater: uint32(2), msg: `"2" is not less than or equal to "1"`}, + {less: uint64(1), greater: uint64(2), msg: `"2" is not less than or equal to "1"`}, + {less: float32(1.23), greater: float32(2.34), msg: `"2.34" is not less than or equal to "1.23"`}, + {less: float64(1.23), greater: float64(2.34), msg: `"2.34" is not less than or equal to "1.23"`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, LessOrEqual(out, currCase.greater, currCase.less)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "go.mongodb.org/mongo-driver/v2/internal/assert.LessOrEqual") + } +} + +func TestPositive(t *testing.T) { + mockT := new(testing.T) + + if !Positive(mockT, 1) { + t.Error("Positive should return true") + } + + if !Positive(mockT, 1.23) { + t.Error("Positive should return true") + } + + if Positive(mockT, -1) { + t.Error("Positive should return false") + } + + if Positive(mockT, -1.23) { + t.Error("Positive should return false") + } + + // Check error report + for _, currCase := range []struct { + e interface{} + msg string + }{ + {e: int(-1), msg: `"-1" is not positive`}, + {e: int8(-1), msg: `"-1" is not positive`}, + {e: int16(-1), msg: `"-1" is not positive`}, + {e: int32(-1), msg: `"-1" is not positive`}, + {e: int64(-1), msg: `"-1" is not positive`}, + {e: float32(-1.23), msg: `"-1.23" is not positive`}, + {e: float64(-1.23), msg: `"-1.23" is not positive`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, Positive(out, currCase.e)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "go.mongodb.org/mongo-driver/v2/internal/assert.Positive") + } +} + +func TestNegative(t *testing.T) { + mockT := new(testing.T) + + if !Negative(mockT, -1) { + t.Error("Negative should return true") + } + + if !Negative(mockT, -1.23) { + t.Error("Negative should return true") + } + + if Negative(mockT, 1) { + t.Error("Negative should return false") + } + + if Negative(mockT, 1.23) { + t.Error("Negative should return false") + } + + // Check error report + for _, currCase := range []struct { + e interface{} + msg string + }{ + {e: int(1), msg: `"1" is not negative`}, + {e: int8(1), msg: `"1" is not negative`}, + {e: int16(1), msg: `"1" is not negative`}, + {e: int32(1), msg: `"1" is not negative`}, + {e: int64(1), msg: `"1" is not negative`}, + {e: float32(1.23), msg: `"1.23" is not negative`}, + {e: float64(1.23), msg: `"1.23" is not negative`}, + } { + out := &outputT{buf: bytes.NewBuffer(nil)} + False(t, Negative(out, currCase.e)) + Contains(t, out.buf.String(), currCase.msg) + Contains(t, out.helpers, "go.mongodb.org/mongo-driver/v2/internal/assert.Negative") + } +} + +func Test_compareTwoValuesDifferentValuesTypes(t *testing.T) { + mockT := new(testing.T) + + for _, currCase := range []struct { + v1 interface{} + v2 interface{} + compareResult bool + }{ + {v1: 123, v2: "abc"}, + {v1: "abc", v2: 123456}, + {v1: float64(12), v2: "123"}, + {v1: "float(12)", v2: float64(1)}, + } { + compareResult := compareTwoValues(mockT, currCase.v1, currCase.v2, []CompareType{compareLess, compareEqual, compareGreater}, "testFailMessage") + False(t, compareResult) + } +} + +func Test_compareTwoValuesNotComparableValues(t *testing.T) { + mockT := new(testing.T) + + type CompareStruct struct { + } + + for _, currCase := range []struct { + v1 interface{} + v2 interface{} + }{ + {v1: CompareStruct{}, v2: CompareStruct{}}, + {v1: map[string]int{}, v2: map[string]int{}}, + {v1: make([]int, 5), v2: make([]int, 5)}, + } { + compareResult := compareTwoValues(mockT, currCase.v1, currCase.v2, []CompareType{compareLess, compareEqual, compareGreater}, "testFailMessage") + False(t, compareResult) + } +} + +func Test_compareTwoValuesCorrectCompareResult(t *testing.T) { + mockT := new(testing.T) + + for _, currCase := range []struct { + v1 interface{} + v2 interface{} + compareTypes []CompareType + }{ + {v1: 1, v2: 2, compareTypes: []CompareType{compareLess}}, + {v1: 1, v2: 2, compareTypes: []CompareType{compareLess, compareEqual}}, + {v1: 2, v2: 2, compareTypes: []CompareType{compareGreater, compareEqual}}, + {v1: 2, v2: 2, compareTypes: []CompareType{compareEqual}}, + {v1: 2, v2: 1, compareTypes: []CompareType{compareEqual, compareGreater}}, + {v1: 2, v2: 1, compareTypes: []CompareType{compareGreater}}, + } { + compareResult := compareTwoValues(mockT, currCase.v1, currCase.v2, currCase.compareTypes, "testFailMessage") + True(t, compareResult) + } +} + +func Test_containsValue(t *testing.T) { + for _, currCase := range []struct { + values []CompareType + value CompareType + result bool + }{ + {values: []CompareType{compareGreater}, value: compareGreater, result: true}, + {values: []CompareType{compareGreater, compareLess}, value: compareGreater, result: true}, + {values: []CompareType{compareGreater, compareLess}, value: compareLess, result: true}, + {values: []CompareType{compareGreater, compareLess}, value: compareEqual, result: false}, + } { + compareResult := containsValue(currCase.values, currCase.value) + Equal(t, currCase.result, compareResult) + } +} + +func TestComparingMsgAndArgsForwarding(t *testing.T) { + msgAndArgs := []interface{}{"format %s %x", "this", 0xc001} + expectedOutput := "format this c001\n" + funcs := []func(t TestingT){ + func(t TestingT) { Greater(t, 1, 2, msgAndArgs...) }, + func(t TestingT) { GreaterOrEqual(t, 1, 2, msgAndArgs...) }, + func(t TestingT) { Less(t, 2, 1, msgAndArgs...) }, + func(t TestingT) { LessOrEqual(t, 2, 1, msgAndArgs...) }, + func(t TestingT) { Positive(t, 0, msgAndArgs...) }, + func(t TestingT) { Negative(t, 0, msgAndArgs...) }, + } + for _, f := range funcs { + out := &outputT{buf: bytes.NewBuffer(nil)} + f(out) + Contains(t, out.buf.String(), expectedOutput) + } +} diff --git a/drivers/mongov2/internal/assert/assertion_format.go b/drivers/mongov2/internal/assert/assertion_format.go new file mode 100644 index 0000000..474a5f4 --- /dev/null +++ b/drivers/mongov2/internal/assert/assertion_format.go @@ -0,0 +1,325 @@ +// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_format.go + +// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved. +// Use of this source code is governed by an MIT-style license that can be found in +// the THIRD-PARTY-NOTICES file. + +package assert + +import ( + time "time" +) + +// Containsf asserts that the specified string, list(array, slice...) or map contains the +// specified substring or element. +// +// assert.Containsf(t, "Hello World", "World", "error message %s", "formatted") +// assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted") +// assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted") +func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Contains(t, s, contains, append([]interface{}{msg}, args...)...) +} + +// ElementsMatchf asserts that the specified listA(array, slice...) is equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should match. +// +// assert.ElementsMatchf(t, [1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted") +func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return ElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...) +} + +// Equalf asserts that two objects are equal. +// +// assert.Equalf(t, 123, 123, "error message %s", "formatted") +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). Function equality +// cannot be determined and will always fail. +func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Equal(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// EqualErrorf asserts that a function returned an error (i.e. not `nil`) +// and that it is equal to the provided error. +// +// actualObj, err := SomeFunction() +// assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted") +func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return EqualError(t, theError, errString, append([]interface{}{msg}, args...)...) +} + +// EqualValuesf asserts that two objects are equal or convertible to the same types +// and equal. +// +// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") +func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return EqualValues(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// Errorf asserts that a function returned an error (i.e. not `nil`). +// +// actualObj, err := SomeFunction() +// if assert.Errorf(t, err, "error message %s", "formatted") { +// assert.Equal(t, expectedErrorf, err) +// } +func Errorf(t TestingT, err error, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Error(t, err, append([]interface{}{msg}, args...)...) +} + +// ErrorContainsf asserts that a function returned an error (i.e. not `nil`) +// and that the error contains the specified substring. +// +// actualObj, err := SomeFunction() +// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted") +func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return ErrorContains(t, theError, contains, append([]interface{}{msg}, args...)...) +} + +// Eventuallyf asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. +// +// assert.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Eventually(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...) +} + +// Failf reports a failure through +func Failf(t TestingT, failureMessage string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, failureMessage, append([]interface{}{msg}, args...)...) +} + +// FailNowf fails test +func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return FailNow(t, failureMessage, append([]interface{}{msg}, args...)...) +} + +// Falsef asserts that the specified value is false. +// +// assert.Falsef(t, myBool, "error message %s", "formatted") +func Falsef(t TestingT, value bool, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return False(t, value, append([]interface{}{msg}, args...)...) +} + +// Greaterf asserts that the first element is greater than the second +// +// assert.Greaterf(t, 2, 1, "error message %s", "formatted") +// assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted") +// assert.Greaterf(t, "b", "a", "error message %s", "formatted") +func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Greater(t, e1, e2, append([]interface{}{msg}, args...)...) +} + +// GreaterOrEqualf asserts that the first element is greater than or equal to the second +// +// assert.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted") +// assert.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted") +// assert.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted") +// assert.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted") +func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return GreaterOrEqual(t, e1, e2, append([]interface{}{msg}, args...)...) +} + +// InDeltaf asserts that the two numerals are within delta of each other. +// +// assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted") +func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return InDelta(t, expected, actual, delta, append([]interface{}{msg}, args...)...) +} + +// IsTypef asserts that the specified objects are of the same type. +func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return IsType(t, expectedType, object, append([]interface{}{msg}, args...)...) +} + +// Lenf asserts that the specified object has specific length. +// Lenf also fails if the object has a type that len() not accept. +// +// assert.Lenf(t, mySlice, 3, "error message %s", "formatted") +func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Len(t, object, length, append([]interface{}{msg}, args...)...) +} + +// Lessf asserts that the first element is less than the second +// +// assert.Lessf(t, 1, 2, "error message %s", "formatted") +// assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted") +// assert.Lessf(t, "a", "b", "error message %s", "formatted") +func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Less(t, e1, e2, append([]interface{}{msg}, args...)...) +} + +// LessOrEqualf asserts that the first element is less than or equal to the second +// +// assert.LessOrEqualf(t, 1, 2, "error message %s", "formatted") +// assert.LessOrEqualf(t, 2, 2, "error message %s", "formatted") +// assert.LessOrEqualf(t, "a", "b", "error message %s", "formatted") +// assert.LessOrEqualf(t, "b", "b", "error message %s", "formatted") +func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return LessOrEqual(t, e1, e2, append([]interface{}{msg}, args...)...) +} + +// Negativef asserts that the specified element is negative +// +// assert.Negativef(t, -1, "error message %s", "formatted") +// assert.Negativef(t, -1.23, "error message %s", "formatted") +func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Negative(t, e, append([]interface{}{msg}, args...)...) +} + +// Nilf asserts that the specified object is nil. +// +// assert.Nilf(t, err, "error message %s", "formatted") +func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Nil(t, object, append([]interface{}{msg}, args...)...) +} + +// NoErrorf asserts that a function returned no error (i.e. `nil`). +// +// actualObj, err := SomeFunction() +// if assert.NoErrorf(t, err, "error message %s", "formatted") { +// assert.Equal(t, expectedObj, actualObj) +// } +func NoErrorf(t TestingT, err error, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NoError(t, err, append([]interface{}{msg}, args...)...) +} + +// NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the +// specified substring or element. +// +// assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted") +// assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted") +// assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted") +func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotContains(t, s, contains, append([]interface{}{msg}, args...)...) +} + +// NotEqualf asserts that the specified values are NOT equal. +// +// assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted") +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). +func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotEqual(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// NotEqualValuesf asserts that two objects are not equal even when converted to the same type +// +// assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted") +func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotEqualValues(t, expected, actual, append([]interface{}{msg}, args...)...) +} + +// NotNilf asserts that the specified object is not nil. +// +// assert.NotNilf(t, err, "error message %s", "formatted") +func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotNil(t, object, append([]interface{}{msg}, args...)...) +} + +// Positivef asserts that the specified element is positive +// +// assert.Positivef(t, 1, "error message %s", "formatted") +// assert.Positivef(t, 1.23, "error message %s", "formatted") +func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Positive(t, e, append([]interface{}{msg}, args...)...) +} + +// Truef asserts that the specified value is true. +// +// assert.Truef(t, myBool, "error message %s", "formatted") +func Truef(t TestingT, value bool, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return True(t, value, append([]interface{}{msg}, args...)...) +} + +// WithinDurationf asserts that the two times are within duration delta of each other. +// +// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") +func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...) +} diff --git a/drivers/mongov2/internal/assert/assertion_mongo.go b/drivers/mongov2/internal/assert/assertion_mongo.go new file mode 100644 index 0000000..e47fdf9 --- /dev/null +++ b/drivers/mongov2/internal/assert/assertion_mongo.go @@ -0,0 +1,126 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +// assertion_mongo.go contains MongoDB-specific extensions to the "assert" +// package. + +package assert + +import ( + "context" + "fmt" + "reflect" + "time" + "unsafe" +) + +// DifferentAddressRanges asserts that two byte slices reference distinct memory +// address ranges, meaning they reference different underlying byte arrays. +func DifferentAddressRanges(t TestingT, a, b []byte) (ok bool) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + if len(a) == 0 || len(b) == 0 { + return true + } + + // Find the start and end memory addresses for the underlying byte array for + // each input byte slice. + sliceAddrRange := func(b []byte) (uintptr, uintptr) { + sh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + return sh.Data, sh.Data + uintptr(sh.Cap-1) + } + aStart, aEnd := sliceAddrRange(a) + bStart, bEnd := sliceAddrRange(b) + + // If "b" starts after "a" ends or "a" starts after "b" ends, there is no + // overlap. + if bStart > aEnd || aStart > bEnd { + return true + } + + // Otherwise, calculate the overlap start and end and print the memory + // overlap error message. + min := func(a, b uintptr) uintptr { + if a < b { + return a + } + return b + } + max := func(a, b uintptr) uintptr { + if a > b { + return a + } + return b + } + overlapLow := max(aStart, bStart) + overlapHigh := min(aEnd, bEnd) + + t.Errorf("Byte slices point to the same underlying byte array:\n"+ + "\ta addresses:\t%d ... %d\n"+ + "\tb addresses:\t%d ... %d\n"+ + "\toverlap:\t%d ... %d", + aStart, aEnd, + bStart, bEnd, + overlapLow, overlapHigh) + + return false +} + +// EqualBSON asserts that the expected and actual BSON binary values are equal. +// If the values are not equal, it prints both the binary and Extended JSON diff +// of the BSON values. The provided BSON value types must implement the +// fmt.Stringer interface. +func EqualBSON(t TestingT, expected, actual interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + return Equal(t, + expected, + actual, + `expected and actual BSON values do not match +As Extended JSON: +Expected: %s +Actual : %s`, + expected.(fmt.Stringer).String(), + actual.(fmt.Stringer).String()) +} + +// Soon runs the provided callback and fails the passed-in test if the callback +// does not complete within timeout. The provided callback should respect the +// passed-in context and cease execution when it has expired. +// +// Deprecated: This function will be removed with GODRIVER-2667, use +// assert.Eventually instead. +func Soon(t TestingT, callback func(ctx context.Context), timeout time.Duration) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + // Create context to manually cancel callback after Soon assertion. + callbackCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan struct{}) + fullCallback := func() { + callback(callbackCtx) + done <- struct{}{} + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + + go fullCallback() + + select { + case <-done: + return + case <-timer.C: + t.Errorf("timed out in %s waiting for callback", timeout) + } +} diff --git a/drivers/mongov2/internal/assert/assertion_mongo_test.go b/drivers/mongov2/internal/assert/assertion_mongo_test.go new file mode 100644 index 0000000..9fe6f48 --- /dev/null +++ b/drivers/mongov2/internal/assert/assertion_mongo_test.go @@ -0,0 +1,125 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package assert + +import ( + "testing" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +func TestDifferentAddressRanges(t *testing.T) { + t.Parallel() + + slice := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + + testCases := []struct { + name string + a []byte + b []byte + want bool + }{ + { + name: "distinct byte slices", + a: []byte{0, 1, 2, 3}, + b: []byte{0, 1, 2, 3}, + want: true, + }, + { + name: "same byte slice", + a: slice, + b: slice, + want: false, + }, + { + name: "whole and subslice", + a: slice, + b: slice[:4], + want: false, + }, + { + name: "two subslices", + a: slice[1:2], + b: slice[3:4], + want: false, + }, + { + name: "empty", + a: []byte{0, 1, 2, 3}, + b: []byte{}, + want: true, + }, + { + name: "nil", + a: []byte{0, 1, 2, 3}, + b: nil, + want: true, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := DifferentAddressRanges(new(testing.T), tc.a, tc.b) + if got != tc.want { + t.Errorf("DifferentAddressRanges(%p, %p) = %v, want %v", tc.a, tc.b, got, tc.want) + } + }) + } +} + +func TestEqualBSON(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + expected interface{} + actual interface{} + want bool + }{ + { + name: "equal bson.Raw", + expected: bson.Raw{5, 0, 0, 0, 0}, + actual: bson.Raw{5, 0, 0, 0, 0}, + want: true, + }, + { + name: "different bson.Raw", + expected: bson.Raw{8, 0, 0, 0, 10, 120, 0, 0}, + actual: bson.Raw{5, 0, 0, 0, 0}, + want: false, + }, + { + name: "invalid bson.Raw", + expected: bson.Raw{99, 99, 99, 99}, + actual: bson.Raw{5, 0, 0, 0, 0}, + want: false, + }, + { + name: "nil bson.Raw", + expected: bson.Raw(nil), + actual: bson.Raw(nil), + want: true, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := EqualBSON(new(testing.T), tc.expected, tc.actual) + if got != tc.want { + t.Errorf("EqualBSON(%#v, %#v) = %v, want %v", tc.expected, tc.actual, got, tc.want) + } + }) + } +} diff --git a/drivers/mongov2/internal/assert/assertions.go b/drivers/mongov2/internal/assert/assertions.go new file mode 100644 index 0000000..c227d47 --- /dev/null +++ b/drivers/mongov2/internal/assert/assertions.go @@ -0,0 +1,1075 @@ +// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertions.go + +// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved. +// Use of this source code is governed by an MIT-style license that can be found in +// the THIRD-PARTY-NOTICES file. + +package assert + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "math" + "path/filepath" + "reflect" + "runtime" + "strings" + "time" + "unicode" + "unicode/utf8" + + "github.com/davecgh/go-spew/spew" +) + +// TestingT is an interface wrapper around *testing.T +type TestingT interface { + Errorf(format string, args ...interface{}) +} + +// ObjectsAreEqual determines if two objects are considered equal. +// +// This function does no assertion of any kind. +func ObjectsAreEqual(expected, actual interface{}) bool { + if expected == nil || actual == nil { + return expected == actual + } + + exp, ok := expected.([]byte) + if !ok { + return reflect.DeepEqual(expected, actual) + } + + act, ok := actual.([]byte) + if !ok { + return false + } + if exp == nil || act == nil { + return exp == nil && act == nil + } + return bytes.Equal(exp, act) +} + +// ObjectsAreEqualValues gets whether two objects are equal, or if their +// values are equal. +func ObjectsAreEqualValues(expected, actual interface{}) bool { + if ObjectsAreEqual(expected, actual) { + return true + } + + actualType := reflect.TypeOf(actual) + if actualType == nil { + return false + } + expectedValue := reflect.ValueOf(expected) + if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { + // Attempt comparison after type conversion + return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) + } + + return false +} + +/* CallerInfo is necessary because the assert functions use the testing object +internally, causing it to print the file:line of the assert method, rather than where +the problem actually occurred in calling code.*/ + +// CallerInfo returns an array of strings containing the file and line number +// of each stack frame leading from the current test to the assert call that +// failed. +func CallerInfo() []string { + + var pc uintptr + var ok bool + var file string + var line int + var name string + + callers := []string{} + for i := 0; ; i++ { + pc, file, line, ok = runtime.Caller(i) + if !ok { + // The breaks below failed to terminate the loop, and we ran off the + // end of the call stack. + break + } + + // This is a huge edge case, but it will panic if this is the case, see #180 + if file == "" { + break + } + + f := runtime.FuncForPC(pc) + if f == nil { + break + } + name = f.Name() + + // testing.tRunner is the standard library function that calls + // tests. Subtests are called directly by tRunner, without going through + // the Test/Benchmark/Example function that contains the t.Run calls, so + // with subtests we should break when we hit tRunner, without adding it + // to the list of callers. + if name == "testing.tRunner" { + break + } + + parts := strings.Split(file, "/") + file = parts[len(parts)-1] + if len(parts) > 1 { + dir := parts[len(parts)-2] + if (dir != "assert" && dir != "mock" && dir != "require") || file == "mock_test.go" { + path, _ := filepath.Abs(file) + callers = append(callers, fmt.Sprintf("%s:%d", path, line)) + } + } + + // Drop the package + segments := strings.Split(name, ".") + name = segments[len(segments)-1] + if isTest(name, "Test") || + isTest(name, "Benchmark") || + isTest(name, "Example") { + break + } + } + + return callers +} + +// Stolen from the `go test` tool. +// isTest tells whether name looks like a test (or benchmark, according to prefix). +// It is a Test (say) if there is a character after Test that is not a lower-case letter. +// We don't want TesticularCancer. +func isTest(name, prefix string) bool { + if !strings.HasPrefix(name, prefix) { + return false + } + if len(name) == len(prefix) { // "Test" is ok + return true + } + r, _ := utf8.DecodeRuneInString(name[len(prefix):]) + return !unicode.IsLower(r) +} + +func messageFromMsgAndArgs(msgAndArgs ...interface{}) string { + if len(msgAndArgs) == 0 || msgAndArgs == nil { + return "" + } + if len(msgAndArgs) == 1 { + msg := msgAndArgs[0] + if msgAsStr, ok := msg.(string); ok { + return msgAsStr + } + return fmt.Sprintf("%+v", msg) + } + if len(msgAndArgs) > 1 { + return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...) + } + return "" +} + +// Aligns the provided message so that all lines after the first line start at the same location as the first line. +// Assumes that the first line starts at the correct location (after carriage return, tab, label, spacer and tab). +// The longestLabelLen parameter specifies the length of the longest label in the output (required because this is the +// basis on which the alignment occurs). +func indentMessageLines(message string, longestLabelLen int) string { + outBuf := new(bytes.Buffer) + + for i, scanner := 0, bufio.NewScanner(strings.NewReader(message)); scanner.Scan(); i++ { + // no need to align first line because it starts at the correct location (after the label) + if i != 0 { + // append alignLen+1 spaces to align with "{{longestLabel}}:" before adding tab + outBuf.WriteString("\n\t" + strings.Repeat(" ", longestLabelLen+1) + "\t") + } + outBuf.WriteString(scanner.Text()) + } + + return outBuf.String() +} + +type failNower interface { + FailNow() +} + +// FailNow fails test +func FailNow(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + Fail(t, failureMessage, msgAndArgs...) + + // We cannot extend TestingT with FailNow() and + // maintain backwards compatibility, so we fallback + // to panicking when FailNow is not available in + // TestingT. + // See issue #263 + + if t, ok := t.(failNower); ok { + t.FailNow() + } else { + panic("test failed and t is missing `FailNow()`") + } + return false +} + +// Fail reports a failure through +func Fail(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + content := []labeledContent{ + {"Error Trace", strings.Join(CallerInfo(), "\n\t\t\t")}, + {"Error", failureMessage}, + } + + // Add test name if the Go version supports it + if n, ok := t.(interface { + Name() string + }); ok { + content = append(content, labeledContent{"Test", n.Name()}) + } + + message := messageFromMsgAndArgs(msgAndArgs...) + if len(message) > 0 { + content = append(content, labeledContent{"Messages", message}) + } + + t.Errorf("\n%s", ""+labeledOutput(content...)) + + return false +} + +type labeledContent struct { + label string + content string +} + +// labeledOutput returns a string consisting of the provided labeledContent. Each labeled output is appended in the following manner: +// +// \t{{label}}:{{align_spaces}}\t{{content}}\n +// +// The initial carriage return is required to undo/erase any padding added by testing.T.Errorf. The "\t{{label}}:" is for the label. +// If a label is shorter than the longest label provided, padding spaces are added to make all the labels match in length. Once this +// alignment is achieved, "\t{{content}}\n" is added for the output. +// +// If the content of the labeledOutput contains line breaks, the subsequent lines are aligned so that they start at the same location as the first line. +func labeledOutput(content ...labeledContent) string { + longestLabel := 0 + for _, v := range content { + if len(v.label) > longestLabel { + longestLabel = len(v.label) + } + } + var output string + for _, v := range content { + output += "\t" + v.label + ":" + strings.Repeat(" ", longestLabel-len(v.label)) + "\t" + indentMessageLines(v.content, longestLabel) + "\n" + } + return output +} + +// IsType asserts that the specified objects are of the same type. +func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + if !ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) { + return Fail(t, fmt.Sprintf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)), msgAndArgs...) + } + + return true +} + +// Equal asserts that two objects are equal. +// +// assert.Equal(t, 123, 123) +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). Function equality +// cannot be determined and will always fail. +func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if err := validateEqualArgs(expected, actual); err != nil { + return Fail(t, fmt.Sprintf("Invalid operation: %#v == %#v (%s)", + expected, actual, err), msgAndArgs...) + } + + if !ObjectsAreEqual(expected, actual) { + diff := diff(expected, actual) + expected, actual = formatUnequalValues(expected, actual) + return Fail(t, fmt.Sprintf("Not equal: \n"+ + "expected: %s\n"+ + "actual : %s%s", expected, actual, diff), msgAndArgs...) + } + + return true + +} + +// validateEqualArgs checks whether provided arguments can be safely used in the +// Equal/NotEqual functions. +func validateEqualArgs(expected, actual interface{}) error { + if expected == nil && actual == nil { + return nil + } + + if isFunction(expected) || isFunction(actual) { + return errors.New("cannot take func type as argument") + } + return nil +} + +// formatUnequalValues takes two values of arbitrary types and returns string +// representations appropriate to be presented to the user. +// +// If the values are not of like type, the returned strings will be prefixed +// with the type name, and the value will be enclosed in parenthesis similar +// to a type conversion in the Go grammar. +func formatUnequalValues(expected, actual interface{}) (e string, a string) { + if reflect.TypeOf(expected) != reflect.TypeOf(actual) { + return fmt.Sprintf("%T(%s)", expected, truncatingFormat(expected)), + fmt.Sprintf("%T(%s)", actual, truncatingFormat(actual)) + } + switch expected.(type) { + case time.Duration: + return fmt.Sprintf("%v", expected), fmt.Sprintf("%v", actual) + } + return truncatingFormat(expected), truncatingFormat(actual) +} + +// truncatingFormat formats the data and truncates it if it's too long. +// +// This helps keep formatted error messages lines from exceeding the +// bufio.MaxScanTokenSize max line length that the go testing framework imposes. +func truncatingFormat(data interface{}) string { + value := fmt.Sprintf("%#v", data) + max := bufio.MaxScanTokenSize - 100 // Give us some space the type info too if needed. + if len(value) > max { + value = value[0:max] + "<... truncated>" + } + return value +} + +// EqualValues asserts that two objects are equal or convertible to the same types +// and equal. +// +// assert.EqualValues(t, uint32(123), int32(123)) +func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + if !ObjectsAreEqualValues(expected, actual) { + diff := diff(expected, actual) + expected, actual = formatUnequalValues(expected, actual) + return Fail(t, fmt.Sprintf("Not equal: \n"+ + "expected: %s\n"+ + "actual : %s%s", expected, actual, diff), msgAndArgs...) + } + + return true + +} + +// NotNil asserts that the specified object is not nil. +// +// assert.NotNil(t, err) +func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + if !isNil(object) { + return true + } + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, "Expected value not to be nil.", msgAndArgs...) +} + +// containsKind checks if a specified kind in the slice of kinds. +func containsKind(kinds []reflect.Kind, kind reflect.Kind) bool { + for i := 0; i < len(kinds); i++ { + if kind == kinds[i] { + return true + } + } + + return false +} + +// isNil checks if a specified object is nil or not, without Failing. +func isNil(object interface{}) bool { + if object == nil { + return true + } + + value := reflect.ValueOf(object) + kind := value.Kind() + isNilableKind := containsKind( + []reflect.Kind{ + reflect.Chan, reflect.Func, + reflect.Interface, reflect.Map, + reflect.Ptr, reflect.Slice}, + kind) + + if isNilableKind && value.IsNil() { + return true + } + + return false +} + +// Nil asserts that the specified object is nil. +// +// assert.Nil(t, err) +func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + if isNil(object) { + return true + } + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, fmt.Sprintf("Expected nil, but got: %#v", object), msgAndArgs...) +} + +// getLen try to get length of object. +// return (false, 0) if impossible. +func getLen(x interface{}) (ok bool, length int) { + v := reflect.ValueOf(x) + defer func() { + if e := recover(); e != nil { + ok = false + } + }() + return true, v.Len() +} + +// Len asserts that the specified object has specific length. +// Len also fails if the object has a type that len() not accept. +// +// assert.Len(t, mySlice, 3) +func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + ok, l := getLen(object) + if !ok { + return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", object), msgAndArgs...) + } + + if l != length { + return Fail(t, fmt.Sprintf("\"%s\" should have %d item(s), but has %d", object, length, l), msgAndArgs...) + } + return true +} + +// True asserts that the specified value is true. +// +// assert.True(t, myBool) +func True(t TestingT, value bool, msgAndArgs ...interface{}) bool { + if !value { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, "Should be true", msgAndArgs...) + } + + return true + +} + +// False asserts that the specified value is false. +// +// assert.False(t, myBool) +func False(t TestingT, value bool, msgAndArgs ...interface{}) bool { + if value { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, "Should be false", msgAndArgs...) + } + + return true + +} + +// NotEqual asserts that the specified values are NOT equal. +// +// assert.NotEqual(t, obj1, obj2) +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). +func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if err := validateEqualArgs(expected, actual); err != nil { + return Fail(t, fmt.Sprintf("Invalid operation: %#v != %#v (%s)", + expected, actual, err), msgAndArgs...) + } + + if ObjectsAreEqual(expected, actual) { + return Fail(t, fmt.Sprintf("Should not be: %#v\n", actual), msgAndArgs...) + } + + return true + +} + +// NotEqualValues asserts that two objects are not equal even when converted to the same type +// +// assert.NotEqualValues(t, obj1, obj2) +func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + if ObjectsAreEqualValues(expected, actual) { + return Fail(t, fmt.Sprintf("Should not be: %#v\n", actual), msgAndArgs...) + } + + return true +} + +// containsElement try loop over the list check if the list includes the element. +// return (false, false) if impossible. +// return (true, false) if element was not found. +// return (true, true) if element was found. +func containsElement(list interface{}, element interface{}) (ok, found bool) { + + listValue := reflect.ValueOf(list) + listType := reflect.TypeOf(list) + if listType == nil { + return false, false + } + listKind := listType.Kind() + defer func() { + if e := recover(); e != nil { + ok = false + found = false + } + }() + + if listKind == reflect.String { + elementValue := reflect.ValueOf(element) + return true, strings.Contains(listValue.String(), elementValue.String()) + } + + if listKind == reflect.Map { + mapKeys := listValue.MapKeys() + for i := 0; i < len(mapKeys); i++ { + if ObjectsAreEqual(mapKeys[i].Interface(), element) { + return true, true + } + } + return true, false + } + + for i := 0; i < listValue.Len(); i++ { + if ObjectsAreEqual(listValue.Index(i).Interface(), element) { + return true, true + } + } + return true, false + +} + +// Contains asserts that the specified string, list(array, slice...) or map contains the +// specified substring or element. +// +// assert.Contains(t, "Hello World", "World") +// assert.Contains(t, ["Hello", "World"], "World") +// assert.Contains(t, {"Hello": "World"}, "Hello") +func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + ok, found := containsElement(s, contains) + if !ok { + return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...) + } + if !found { + return Fail(t, fmt.Sprintf("%#v does not contain %#v", s, contains), msgAndArgs...) + } + + return true + +} + +// NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the +// specified substring or element. +// +// assert.NotContains(t, "Hello World", "Earth") +// assert.NotContains(t, ["Hello", "World"], "Earth") +// assert.NotContains(t, {"Hello": "World"}, "Earth") +func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + ok, found := containsElement(s, contains) + if !ok { + return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...) + } + if found { + return Fail(t, fmt.Sprintf("\"%s\" should not contain \"%s\"", s, contains), msgAndArgs...) + } + + return true + +} + +// isEmpty gets whether the specified object is considered empty or not. +func isEmpty(object interface{}) bool { + + // get nil case out of the way + if object == nil { + return true + } + + objValue := reflect.ValueOf(object) + + switch objValue.Kind() { + // collection types are empty when they have no element + case reflect.Chan, reflect.Map, reflect.Slice: + return objValue.Len() == 0 + // pointers are empty if nil or if the value they point to is empty + case reflect.Ptr: + if objValue.IsNil() { + return true + } + deref := objValue.Elem().Interface() + return isEmpty(deref) + // for all other types, compare against the zero value + // array types are empty when they match their zero-initialized state + default: + zero := reflect.Zero(objValue.Type()) + return reflect.DeepEqual(object, zero.Interface()) + } +} + +// ElementsMatch asserts that the specified listA(array, slice...) is equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should match. +// +// assert.ElementsMatch(t, [1, 3, 2, 3], [1, 3, 3, 2]) +func ElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if isEmpty(listA) && isEmpty(listB) { + return true + } + + if !isList(t, listA, msgAndArgs...) || !isList(t, listB, msgAndArgs...) { + return false + } + + extraA, extraB := diffLists(listA, listB) + + if len(extraA) == 0 && len(extraB) == 0 { + return true + } + + return Fail(t, formatListDiff(listA, listB, extraA, extraB), msgAndArgs...) +} + +// isList checks that the provided value is array or slice. +func isList(t TestingT, list interface{}, msgAndArgs ...interface{}) (ok bool) { + kind := reflect.TypeOf(list).Kind() + if kind != reflect.Array && kind != reflect.Slice { + return Fail(t, fmt.Sprintf("%q has an unsupported type %s, expecting array or slice", list, kind), + msgAndArgs...) + } + return true +} + +// diffLists diffs two arrays/slices and returns slices of elements that are only in A and only in B. +// If some element is present multiple times, each instance is counted separately (e.g. if something is 2x in A and +// 5x in B, it will be 0x in extraA and 3x in extraB). The order of items in both lists is ignored. +func diffLists(listA, listB interface{}) ([]interface{}, []interface{}) { + var extraA, extraB []interface{} + + aValue := reflect.ValueOf(listA) + bValue := reflect.ValueOf(listB) + + aLen := aValue.Len() + bLen := bValue.Len() + + // Mark indexes in bValue that we already used + visited := make([]bool, bLen) + for i := 0; i < aLen; i++ { + element := aValue.Index(i).Interface() + found := false + for j := 0; j < bLen; j++ { + if visited[j] { + continue + } + if ObjectsAreEqual(bValue.Index(j).Interface(), element) { + visited[j] = true + found = true + break + } + } + if !found { + extraA = append(extraA, element) + } + } + + for j := 0; j < bLen; j++ { + if visited[j] { + continue + } + extraB = append(extraB, bValue.Index(j).Interface()) + } + + return extraA, extraB +} + +func formatListDiff(listA, listB interface{}, extraA, extraB []interface{}) string { + var msg bytes.Buffer + + msg.WriteString("elements differ") + if len(extraA) > 0 { + msg.WriteString("\n\nextra elements in list A:\n") + msg.WriteString(spewConfig.Sdump(extraA)) + } + if len(extraB) > 0 { + msg.WriteString("\n\nextra elements in list B:\n") + msg.WriteString(spewConfig.Sdump(extraB)) + } + msg.WriteString("\n\nlistA:\n") + msg.WriteString(spewConfig.Sdump(listA)) + msg.WriteString("\n\nlistB:\n") + msg.WriteString(spewConfig.Sdump(listB)) + + return msg.String() +} + +// WithinDuration asserts that the two times are within duration delta of each other. +// +// assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second) +func WithinDuration(t TestingT, expected, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + dt := expected.Sub(actual) + if dt < -delta || dt > delta { + return Fail(t, fmt.Sprintf("Max difference between %v and %v allowed is %v, but difference was %v", expected, actual, delta, dt), msgAndArgs...) + } + + return true +} + +func toFloat(x interface{}) (float64, bool) { + var xf float64 + xok := true + + switch xn := x.(type) { + case uint: + xf = float64(xn) + case uint8: + xf = float64(xn) + case uint16: + xf = float64(xn) + case uint32: + xf = float64(xn) + case uint64: + xf = float64(xn) + case int: + xf = float64(xn) + case int8: + xf = float64(xn) + case int16: + xf = float64(xn) + case int32: + xf = float64(xn) + case int64: + xf = float64(xn) + case float32: + xf = float64(xn) + case float64: + xf = xn + case time.Duration: + xf = float64(xn) + default: + xok = false + } + + return xf, xok +} + +// InDelta asserts that the two numerals are within delta of each other. +// +// assert.InDelta(t, math.Pi, 22/7.0, 0.01) +func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + af, aok := toFloat(expected) + bf, bok := toFloat(actual) + + if !aok || !bok { + return Fail(t, "Parameters must be numerical", msgAndArgs...) + } + + if math.IsNaN(af) && math.IsNaN(bf) { + return true + } + + if math.IsNaN(af) { + return Fail(t, "Expected must not be NaN", msgAndArgs...) + } + + if math.IsNaN(bf) { + return Fail(t, fmt.Sprintf("Expected %v with delta %v, but was NaN", expected, delta), msgAndArgs...) + } + + dt := af - bf + if dt < -delta || dt > delta { + return Fail(t, fmt.Sprintf("Max difference between %v and %v allowed is %v, but difference was %v", expected, actual, delta, dt), msgAndArgs...) + } + + return true +} + +/* + Errors +*/ + +// NoError asserts that a function returned no error (i.e. `nil`). +// +// actualObj, err := SomeFunction() +// if assert.NoError(t, err) { +// assert.Equal(t, expectedObj, actualObj) +// } +func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool { + if err != nil { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, fmt.Sprintf("Received unexpected error:\n%+v", err), msgAndArgs...) + } + + return true +} + +// Error asserts that a function returned an error (i.e. not `nil`). +// +// actualObj, err := SomeFunction() +// if assert.Error(t, err) { +// assert.Equal(t, expectedError, err) +// } +func Error(t TestingT, err error, msgAndArgs ...interface{}) bool { + if err == nil { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, "An error is expected but got nil.", msgAndArgs...) + } + + return true +} + +// EqualError asserts that a function returned an error (i.e. not `nil`) +// and that it is equal to the provided error. +// +// actualObj, err := SomeFunction() +// assert.EqualError(t, err, expectedErrorString) +func EqualError(t TestingT, theError error, errString string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if !Error(t, theError, msgAndArgs...) { + return false + } + expected := errString + actual := theError.Error() + // don't need to use deep equals here, we know they are both strings + if expected != actual { + return Fail(t, fmt.Sprintf("Error message not equal:\n"+ + "expected: %q\n"+ + "actual : %q", expected, actual), msgAndArgs...) + } + return true +} + +// ErrorIs asserts that at least one of the errors in err's chain matches target. +// This is a wrapper for errors.Is. +func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if errors.Is(err, target) { + return true + } + + var expectedText string + if target != nil { + expectedText = target.Error() + } + + chain := buildErrorChainString(err) + + return Fail(t, fmt.Sprintf("Target error should be in err chain:\n"+ + "expected: %q\n"+ + "in chain: %s", expectedText, chain, + ), msgAndArgs...) +} + +// ErrorContains asserts that a function returned an error (i.e. not `nil`) +// and that the error contains the specified substring. +// +// actualObj, err := SomeFunction() +// assert.ErrorContains(t, err, expectedErrorSubString) +func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if !Error(t, theError, msgAndArgs...) { + return false + } + + actual := theError.Error() + if !strings.Contains(actual, contains) { + return Fail(t, fmt.Sprintf("Error %#v does not contain %#v", actual, contains), msgAndArgs...) + } + + return true +} + +func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) { + t := reflect.TypeOf(v) + k := t.Kind() + + if k == reflect.Ptr { + t = t.Elem() + k = t.Kind() + } + return t, k +} + +// diff returns a diff of both values as long as both are of the same type and +// are a struct, map, slice, array or string. Otherwise it returns an empty string. +func diff(expected interface{}, actual interface{}) string { + if expected == nil || actual == nil { + return "" + } + + et, ek := typeAndKind(expected) + at, _ := typeAndKind(actual) + + if et != at { + return "" + } + + if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array && ek != reflect.String { + return "" + } + + var e, a string + + switch et { + case reflect.TypeOf(""): + e = reflect.ValueOf(expected).String() + a = reflect.ValueOf(actual).String() + case reflect.TypeOf(time.Time{}): + e = spewConfigStringerEnabled.Sdump(expected) + a = spewConfigStringerEnabled.Sdump(actual) + default: + e = spewConfig.Sdump(expected) + a = spewConfig.Sdump(actual) + } + + diff, _ := GetUnifiedDiffString(UnifiedDiff{ + A: SplitLines(e), + B: SplitLines(a), + FromFile: "Expected", + FromDate: "", + ToFile: "Actual", + ToDate: "", + Context: 1, + }) + + return "\n\nDiff:\n" + diff +} + +func isFunction(arg interface{}) bool { + if arg == nil { + return false + } + return reflect.TypeOf(arg).Kind() == reflect.Func +} + +var spewConfig = spew.ConfigState{ + Indent: " ", + DisablePointerAddresses: true, + DisableCapacities: true, + SortKeys: true, + DisableMethods: true, + MaxDepth: 10, +} + +var spewConfigStringerEnabled = spew.ConfigState{ + Indent: " ", + DisablePointerAddresses: true, + DisableCapacities: true, + SortKeys: true, + MaxDepth: 10, +} + +type tHelper interface { + Helper() +} + +// Eventually asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. +// +// assert.Eventually(t, func() bool { return true; }, time.Second, 10*time.Millisecond) +func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + ch := make(chan bool, 1) + + timer := time.NewTimer(waitFor) + defer timer.Stop() + + ticker := time.NewTicker(tick) + defer ticker.Stop() + + for tick := ticker.C; ; { + select { + case <-timer.C: + return Fail(t, "Condition never satisfied", msgAndArgs...) + case <-tick: + tick = nil + go func() { ch <- condition() }() + case v := <-ch: + if v { + return true + } + tick = ticker.C + } + } +} + +func buildErrorChainString(err error) string { + if err == nil { + return "" + } + + e := errors.Unwrap(err) + chain := fmt.Sprintf("%q", err.Error()) + for e != nil { + chain += fmt.Sprintf("\n\t%q", e.Error()) + e = errors.Unwrap(e) + } + return chain +} diff --git a/drivers/mongov2/internal/assert/assertions_test.go b/drivers/mongov2/internal/assert/assertions_test.go new file mode 100644 index 0000000..6c03bbd --- /dev/null +++ b/drivers/mongov2/internal/assert/assertions_test.go @@ -0,0 +1,1231 @@ +// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertions_test.go + +// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved. +// Use of this source code is governed by an MIT-style license that can be found in +// the THIRD-PARTY-NOTICES file. + +package assert + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "math" + "reflect" + "runtime" + "strings" + "testing" + "time" +) + +// AssertionTesterInterface defines an interface to be used for testing assertion methods +type AssertionTesterInterface interface { + TestMethod() +} + +// AssertionTesterConformingObject is an object that conforms to the AssertionTesterInterface interface +type AssertionTesterConformingObject struct { +} + +func (a *AssertionTesterConformingObject) TestMethod() { +} + +// AssertionTesterNonConformingObject is an object that does not conform to the AssertionTesterInterface interface +type AssertionTesterNonConformingObject struct { +} + +func TestObjectsAreEqual(t *testing.T) { + cases := []struct { + expected interface{} + actual interface{} + result bool + }{ + // cases that are expected to be equal + {"Hello World", "Hello World", true}, + {123, 123, true}, + {123.5, 123.5, true}, + {[]byte("Hello World"), []byte("Hello World"), true}, + {nil, nil, true}, + + // cases that are expected not to be equal + {map[int]int{5: 10}, map[int]int{10: 20}, false}, + {'x', "x", false}, + {"x", 'x', false}, + {0, 0.1, false}, + {0.1, 0, false}, + {time.Now, time.Now, false}, + {func() {}, func() {}, false}, + {uint32(10), int32(10), false}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("ObjectsAreEqual(%#v, %#v)", c.expected, c.actual), func(t *testing.T) { + res := ObjectsAreEqual(c.expected, c.actual) + + if res != c.result { + t.Errorf("ObjectsAreEqual(%#v, %#v) should return %#v", c.expected, c.actual, c.result) + } + + }) + } + + // Cases where type differ but values are equal + if !ObjectsAreEqualValues(uint32(10), int32(10)) { + t.Error("ObjectsAreEqualValues should return true") + } + if ObjectsAreEqualValues(0, nil) { + t.Fail() + } + if ObjectsAreEqualValues(nil, 0) { + t.Fail() + } + +} + +func TestIsType(t *testing.T) { + + mockT := new(testing.T) + + if !IsType(mockT, new(AssertionTesterConformingObject), new(AssertionTesterConformingObject)) { + t.Error("IsType should return true: AssertionTesterConformingObject is the same type as AssertionTesterConformingObject") + } + if IsType(mockT, new(AssertionTesterConformingObject), new(AssertionTesterNonConformingObject)) { + t.Error("IsType should return false: AssertionTesterConformingObject is not the same type as AssertionTesterNonConformingObject") + } + +} + +func TestEqual(t *testing.T) { + type myType string + + mockT := new(testing.T) + var m map[string]interface{} + + cases := []struct { + expected interface{} + actual interface{} + result bool + remark string + }{ + {"Hello World", "Hello World", true, ""}, + {123, 123, true, ""}, + {123.5, 123.5, true, ""}, + {[]byte("Hello World"), []byte("Hello World"), true, ""}, + {nil, nil, true, ""}, + {int32(123), int32(123), true, ""}, + {uint64(123), uint64(123), true, ""}, + {myType("1"), myType("1"), true, ""}, + {&struct{}{}, &struct{}{}, true, "pointer equality is based on equality of underlying value"}, + + // Not expected to be equal + {m["bar"], "something", false, ""}, + {myType("1"), myType("2"), false, ""}, + + // A case that might be confusing, especially with numeric literals + {10, uint(10), false, ""}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("Equal(%#v, %#v)", c.expected, c.actual), func(t *testing.T) { + res := Equal(mockT, c.expected, c.actual) + + if res != c.result { + t.Errorf("Equal(%#v, %#v) should return %#v: %s", c.expected, c.actual, c.result, c.remark) + } + }) + } +} + +// bufferT implements TestingT. Its implementation of Errorf writes the output that would be produced by +// testing.T.Errorf to an internal bytes.Buffer. +type bufferT struct { + buf bytes.Buffer +} + +func (t *bufferT) Errorf(format string, args ...interface{}) { + // implementation of decorate is copied from testing.T + decorate := func(s string) string { + _, file, line, ok := runtime.Caller(3) // decorate + log + public function. + if ok { + // Truncate file name at last file name separator. + if index := strings.LastIndex(file, "/"); index >= 0 { + file = file[index+1:] + } else if index = strings.LastIndex(file, "\\"); index >= 0 { + file = file[index+1:] + } + } else { + file = "???" + line = 1 + } + buf := new(bytes.Buffer) + // Every line is indented at least one tab. + buf.WriteByte('\t') + fmt.Fprintf(buf, "%s:%d: ", file, line) + lines := strings.Split(s, "\n") + if l := len(lines); l > 1 && lines[l-1] == "" { + lines = lines[:l-1] + } + for i, line := range lines { + if i > 0 { + // Second and subsequent lines are indented an extra tab. + buf.WriteString("\n\t\t") + } + buf.WriteString(line) + } + buf.WriteByte('\n') + return buf.String() + } + t.buf.WriteString(decorate(fmt.Sprintf(format, args...))) +} + +func TestStringEqual(_ *testing.T) { + for _, currCase := range []struct { + equalWant string + equalGot string + msgAndArgs []interface{} + want string + }{ + {equalWant: "hi, \nmy name is", equalGot: "what,\nmy name is", want: "\tassertions.go:\\d+: \n\t+Error Trace:\t\n\t+Error:\\s+Not equal:\\s+\n\\s+expected: \"hi, \\\\nmy name is\"\n\\s+actual\\s+: \"what,\\\\nmy name is\"\n\\s+Diff:\n\\s+-+ Expected\n\\s+\\++ Actual\n\\s+@@ -1,2 \\+1,2 @@\n\\s+-hi, \n\\s+\\+what,\n\\s+my name is"}, + } { + mockT := &bufferT{} + Equal(mockT, currCase.equalWant, currCase.equalGot, currCase.msgAndArgs...) + } +} + +func TestEqualFormatting(_ *testing.T) { + for _, currCase := range []struct { + equalWant string + equalGot string + msgAndArgs []interface{} + want string + }{ + {equalWant: "want", equalGot: "got", want: "\tassertions.go:\\d+: \n\t+Error Trace:\t\n\t+Error:\\s+Not equal:\\s+\n\\s+expected: \"want\"\n\\s+actual\\s+: \"got\"\n\\s+Diff:\n\\s+-+ Expected\n\\s+\\++ Actual\n\\s+@@ -1 \\+1 @@\n\\s+-want\n\\s+\\+got\n"}, + {equalWant: "want", equalGot: "got", msgAndArgs: []interface{}{"hello, %v!", "world"}, want: "\tassertions.go:[0-9]+: \n\t+Error Trace:\t\n\t+Error:\\s+Not equal:\\s+\n\\s+expected: \"want\"\n\\s+actual\\s+: \"got\"\n\\s+Diff:\n\\s+-+ Expected\n\\s+\\++ Actual\n\\s+@@ -1 \\+1 @@\n\\s+-want\n\\s+\\+got\n\\s+Messages:\\s+hello, world!\n"}, + {equalWant: "want", equalGot: "got", msgAndArgs: []interface{}{123}, want: "\tassertions.go:[0-9]+: \n\t+Error Trace:\t\n\t+Error:\\s+Not equal:\\s+\n\\s+expected: \"want\"\n\\s+actual\\s+: \"got\"\n\\s+Diff:\n\\s+-+ Expected\n\\s+\\++ Actual\n\\s+@@ -1 \\+1 @@\n\\s+-want\n\\s+\\+got\n\\s+Messages:\\s+123\n"}, + {equalWant: "want", equalGot: "got", msgAndArgs: []interface{}{struct{ a string }{"hello"}}, want: "\tassertions.go:[0-9]+: \n\t+Error Trace:\t\n\t+Error:\\s+Not equal:\\s+\n\\s+expected: \"want\"\n\\s+actual\\s+: \"got\"\n\\s+Diff:\n\\s+-+ Expected\n\\s+\\++ Actual\n\\s+@@ -1 \\+1 @@\n\\s+-want\n\\s+\\+got\n\\s+Messages:\\s+{a:hello}\n"}, + } { + mockT := &bufferT{} + Equal(mockT, currCase.equalWant, currCase.equalGot, currCase.msgAndArgs...) + } +} + +func TestFormatUnequalValues(t *testing.T) { + expected, actual := formatUnequalValues("foo", "bar") + Equal(t, `"foo"`, expected, "value should not include type") + Equal(t, `"bar"`, actual, "value should not include type") + + expected, actual = formatUnequalValues(123, 123) + Equal(t, `123`, expected, "value should not include type") + Equal(t, `123`, actual, "value should not include type") + + expected, actual = formatUnequalValues(int64(123), int32(123)) + Equal(t, `int64(123)`, expected, "value should include type") + Equal(t, `int32(123)`, actual, "value should include type") + + expected, actual = formatUnequalValues(int64(123), nil) + Equal(t, `int64(123)`, expected, "value should include type") + Equal(t, `()`, actual, "value should include type") + + type testStructType struct { + Val string + } + + expected, actual = formatUnequalValues(&testStructType{Val: "test"}, &testStructType{Val: "test"}) + Equal(t, `&assert.testStructType{Val:"test"}`, expected, "value should not include type annotation") + Equal(t, `&assert.testStructType{Val:"test"}`, actual, "value should not include type annotation") +} + +func TestNotNil(t *testing.T) { + + mockT := new(testing.T) + + if !NotNil(mockT, new(AssertionTesterConformingObject)) { + t.Error("NotNil should return true: object is not nil") + } + if NotNil(mockT, nil) { + t.Error("NotNil should return false: object is nil") + } + if NotNil(mockT, (*struct{})(nil)) { + t.Error("NotNil should return false: object is (*struct{})(nil)") + } + +} + +func TestNil(t *testing.T) { + + mockT := new(testing.T) + + if !Nil(mockT, nil) { + t.Error("Nil should return true: object is nil") + } + if !Nil(mockT, (*struct{})(nil)) { + t.Error("Nil should return true: object is (*struct{})(nil)") + } + if Nil(mockT, new(AssertionTesterConformingObject)) { + t.Error("Nil should return false: object is not nil") + } + +} + +func TestTrue(t *testing.T) { + + mockT := new(testing.T) + + if !True(mockT, true) { + t.Error("True should return true") + } + if True(mockT, false) { + t.Error("True should return false") + } + +} + +func TestFalse(t *testing.T) { + + mockT := new(testing.T) + + if !False(mockT, false) { + t.Error("False should return true") + } + if False(mockT, true) { + t.Error("False should return false") + } + +} + +func TestNotEqual(t *testing.T) { + + mockT := new(testing.T) + + cases := []struct { + expected interface{} + actual interface{} + result bool + }{ + // cases that are expected not to match + {"Hello World", "Hello World!", true}, + {123, 1234, true}, + {123.5, 123.55, true}, + {[]byte("Hello World"), []byte("Hello World!"), true}, + {nil, new(AssertionTesterConformingObject), true}, + + // cases that are expected to match + {nil, nil, false}, + {"Hello World", "Hello World", false}, + {123, 123, false}, + {123.5, 123.5, false}, + {[]byte("Hello World"), []byte("Hello World"), false}, + {new(AssertionTesterConformingObject), new(AssertionTesterConformingObject), false}, + {&struct{}{}, &struct{}{}, false}, + {func() int { return 23 }, func() int { return 24 }, false}, + // A case that might be confusing, especially with numeric literals + {int(10), uint(10), true}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("NotEqual(%#v, %#v)", c.expected, c.actual), func(t *testing.T) { + res := NotEqual(mockT, c.expected, c.actual) + + if res != c.result { + t.Errorf("NotEqual(%#v, %#v) should return %#v", c.expected, c.actual, c.result) + } + }) + } +} + +func TestNotEqualValues(t *testing.T) { + mockT := new(testing.T) + + cases := []struct { + expected interface{} + actual interface{} + result bool + }{ + // cases that are expected not to match + {"Hello World", "Hello World!", true}, + {123, 1234, true}, + {123.5, 123.55, true}, + {[]byte("Hello World"), []byte("Hello World!"), true}, + {nil, new(AssertionTesterConformingObject), true}, + + // cases that are expected to match + {nil, nil, false}, + {"Hello World", "Hello World", false}, + {123, 123, false}, + {123.5, 123.5, false}, + {[]byte("Hello World"), []byte("Hello World"), false}, + {new(AssertionTesterConformingObject), new(AssertionTesterConformingObject), false}, + {&struct{}{}, &struct{}{}, false}, + + // Different behaviour from NotEqual() + {func() int { return 23 }, func() int { return 24 }, true}, + {int(10), int(11), true}, + {int(10), uint(10), false}, + + {struct{}{}, struct{}{}, false}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("NotEqualValues(%#v, %#v)", c.expected, c.actual), func(t *testing.T) { + res := NotEqualValues(mockT, c.expected, c.actual) + + if res != c.result { + t.Errorf("NotEqualValues(%#v, %#v) should return %#v", c.expected, c.actual, c.result) + } + }) + } +} + +func TestContainsNotContains(t *testing.T) { + + type A struct { + Name, Value string + } + list := []string{"Foo", "Bar"} + + complexList := []*A{ + {"b", "c"}, + {"d", "e"}, + {"g", "h"}, + {"j", "k"}, + } + simpleMap := map[interface{}]interface{}{"Foo": "Bar"} + var zeroMap map[interface{}]interface{} + + cases := []struct { + expected interface{} + actual interface{} + result bool + }{ + {"Hello World", "Hello", true}, + {"Hello World", "Salut", false}, + {list, "Bar", true}, + {list, "Salut", false}, + {complexList, &A{"g", "h"}, true}, + {complexList, &A{"g", "e"}, false}, + {simpleMap, "Foo", true}, + {simpleMap, "Bar", false}, + {zeroMap, "Bar", false}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("Contains(%#v, %#v)", c.expected, c.actual), func(t *testing.T) { + mockT := new(testing.T) + res := Contains(mockT, c.expected, c.actual) + + if res != c.result { + if res { + t.Errorf("Contains(%#v, %#v) should return true:\n\t%#v contains %#v", c.expected, c.actual, c.expected, c.actual) + } else { + t.Errorf("Contains(%#v, %#v) should return false:\n\t%#v does not contain %#v", c.expected, c.actual, c.expected, c.actual) + } + } + }) + } + + for _, c := range cases { + t.Run(fmt.Sprintf("NotContains(%#v, %#v)", c.expected, c.actual), func(t *testing.T) { + mockT := new(testing.T) + res := NotContains(mockT, c.expected, c.actual) + + // NotContains should be inverse of Contains. If it's not, something is wrong + if res == Contains(mockT, c.expected, c.actual) { + if res { + t.Errorf("NotContains(%#v, %#v) should return true:\n\t%#v does not contains %#v", c.expected, c.actual, c.expected, c.actual) + } else { + t.Errorf("NotContains(%#v, %#v) should return false:\n\t%#v contains %#v", c.expected, c.actual, c.expected, c.actual) + } + } + }) + } +} + +func TestContainsFailMessage(t *testing.T) { + + mockT := new(mockTestingT) + + Contains(mockT, "Hello World", errors.New("Hello")) + expectedFail := "\"Hello World\" does not contain &errors.errorString{s:\"Hello\"}" + actualFail := mockT.errorString() + if !strings.Contains(actualFail, expectedFail) { + t.Errorf("Contains failure should include %q but was %q", expectedFail, actualFail) + } +} + +func TestContainsNotContainsOnNilValue(t *testing.T) { + mockT := new(mockTestingT) + + Contains(mockT, nil, "key") + expectedFail := " could not be applied builtin len()" + actualFail := mockT.errorString() + if !strings.Contains(actualFail, expectedFail) { + t.Errorf("Contains failure should include %q but was %q", expectedFail, actualFail) + } + + NotContains(mockT, nil, "key") + if !strings.Contains(actualFail, expectedFail) { + t.Errorf("Contains failure should include %q but was %q", expectedFail, actualFail) + } +} + +func Test_containsElement(t *testing.T) { + + list1 := []string{"Foo", "Bar"} + list2 := []int{1, 2} + simpleMap := map[interface{}]interface{}{"Foo": "Bar"} + + ok, found := containsElement("Hello World", "World") + True(t, ok) + True(t, found) + + ok, found = containsElement(list1, "Foo") + True(t, ok) + True(t, found) + + ok, found = containsElement(list1, "Bar") + True(t, ok) + True(t, found) + + ok, found = containsElement(list2, 1) + True(t, ok) + True(t, found) + + ok, found = containsElement(list2, 2) + True(t, ok) + True(t, found) + + ok, found = containsElement(list1, "Foo!") + True(t, ok) + False(t, found) + + ok, found = containsElement(list2, 3) + True(t, ok) + False(t, found) + + ok, found = containsElement(list2, "1") + True(t, ok) + False(t, found) + + ok, found = containsElement(simpleMap, "Foo") + True(t, ok) + True(t, found) + + ok, found = containsElement(simpleMap, "Bar") + True(t, ok) + False(t, found) + + ok, found = containsElement(1433, "1") + False(t, ok) + False(t, found) +} + +func TestElementsMatch(t *testing.T) { + mockT := new(testing.T) + + cases := []struct { + expected interface{} + actual interface{} + result bool + }{ + // matching + {nil, nil, true}, + + {nil, nil, true}, + {[]int{}, []int{}, true}, + {[]int{1}, []int{1}, true}, + {[]int{1, 1}, []int{1, 1}, true}, + {[]int{1, 2}, []int{1, 2}, true}, + {[]int{1, 2}, []int{2, 1}, true}, + {[2]int{1, 2}, [2]int{2, 1}, true}, + {[]string{"hello", "world"}, []string{"world", "hello"}, true}, + {[]string{"hello", "hello"}, []string{"hello", "hello"}, true}, + {[]string{"hello", "hello", "world"}, []string{"hello", "world", "hello"}, true}, + {[3]string{"hello", "hello", "world"}, [3]string{"hello", "world", "hello"}, true}, + {[]int{}, nil, true}, + + // not matching + {[]int{1}, []int{1, 1}, false}, + {[]int{1, 2}, []int{2, 2}, false}, + {[]string{"hello", "hello"}, []string{"hello"}, false}, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("ElementsMatch(%#v, %#v)", c.expected, c.actual), func(t *testing.T) { + res := ElementsMatch(mockT, c.actual, c.expected) + + if res != c.result { + t.Errorf("ElementsMatch(%#v, %#v) should return %v", c.actual, c.expected, c.result) + } + }) + } +} + +func TestDiffLists(t *testing.T) { + tests := []struct { + name string + listA interface{} + listB interface{} + extraA []interface{} + extraB []interface{} + }{ + { + name: "equal empty", + listA: []string{}, + listB: []string{}, + extraA: nil, + extraB: nil, + }, + { + name: "equal same order", + listA: []string{"hello", "world"}, + listB: []string{"hello", "world"}, + extraA: nil, + extraB: nil, + }, + { + name: "equal different order", + listA: []string{"hello", "world"}, + listB: []string{"world", "hello"}, + extraA: nil, + extraB: nil, + }, + { + name: "extra A", + listA: []string{"hello", "hello", "world"}, + listB: []string{"hello", "world"}, + extraA: []interface{}{"hello"}, + extraB: nil, + }, + { + name: "extra A twice", + listA: []string{"hello", "hello", "hello", "world"}, + listB: []string{"hello", "world"}, + extraA: []interface{}{"hello", "hello"}, + extraB: nil, + }, + { + name: "extra B", + listA: []string{"hello", "world"}, + listB: []string{"hello", "hello", "world"}, + extraA: nil, + extraB: []interface{}{"hello"}, + }, + { + name: "extra B twice", + listA: []string{"hello", "world"}, + listB: []string{"hello", "hello", "world", "hello"}, + extraA: nil, + extraB: []interface{}{"hello", "hello"}, + }, + { + name: "integers 1", + listA: []int{1, 2, 3, 4, 5}, + listB: []int{5, 4, 3, 2, 1}, + extraA: nil, + extraB: nil, + }, + { + name: "integers 2", + listA: []int{1, 2, 1, 2, 1}, + listB: []int{2, 1, 2, 1, 2}, + extraA: []interface{}{1}, + extraB: []interface{}{2}, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + actualExtraA, actualExtraB := diffLists(test.listA, test.listB) + Equal(t, test.extraA, actualExtraA, "extra A does not match for listA=%v listB=%v", + test.listA, test.listB) + Equal(t, test.extraB, actualExtraB, "extra B does not match for listA=%v listB=%v", + test.listA, test.listB) + }) + } +} + +func TestNoError(t *testing.T) { + + mockT := new(testing.T) + + // start with a nil error + var err error + + True(t, NoError(mockT, err), "NoError should return True for nil arg") + + // now set an error + err = errors.New("some error") + + False(t, NoError(mockT, err), "NoError with error should return False") + + // returning an empty error interface + err = func() error { + var err *customError + return err + }() + + if err == nil { // err is not nil here! + t.Errorf("Error should be nil due to empty interface: %s", err) + } + + False(t, NoError(mockT, err), "NoError should fail with empty error interface") +} + +type customError struct{} + +func (*customError) Error() string { return "fail" } + +func TestError(t *testing.T) { + + mockT := new(testing.T) + + // start with a nil error + var err error + + False(t, Error(mockT, err), "Error should return False for nil arg") + + // now set an error + err = errors.New("some error") + + True(t, Error(mockT, err), "Error with error should return True") + + // go vet check + True(t, Errorf(mockT, err, "example with %s", "formatted message"), "Errorf with error should return True") + + // returning an empty error interface + err = func() error { + var err *customError + return err + }() + + if err == nil { // err is not nil here! + t.Errorf("Error should be nil due to empty interface: %s", err) + } + + True(t, Error(mockT, err), "Error should pass with empty error interface") +} + +func TestEqualError(t *testing.T) { + mockT := new(testing.T) + + // start with a nil error + var err error + False(t, EqualError(mockT, err, ""), + "EqualError should return false for nil arg") + + // now set an error + err = errors.New("some error") + False(t, EqualError(mockT, err, "Not some error"), + "EqualError should return false for different error string") + True(t, EqualError(mockT, err, "some error"), + "EqualError should return true") +} + +func TestErrorContains(t *testing.T) { + mockT := new(testing.T) + + // start with a nil error + var err error + False(t, ErrorContains(mockT, err, ""), + "ErrorContains should return false for nil arg") + + // now set an error + err = errors.New("some error: another error") + False(t, ErrorContains(mockT, err, "bad error"), + "ErrorContains should return false for different error string") + True(t, ErrorContains(mockT, err, "some error"), + "ErrorContains should return true") + True(t, ErrorContains(mockT, err, "another error"), + "ErrorContains should return true") +} + +func Test_isEmpty(t *testing.T) { + + chWithValue := make(chan struct{}, 1) + chWithValue <- struct{}{} + + True(t, isEmpty("")) + True(t, isEmpty(nil)) + True(t, isEmpty([]string{})) + True(t, isEmpty(0)) + True(t, isEmpty(int32(0))) + True(t, isEmpty(int64(0))) + True(t, isEmpty(false)) + True(t, isEmpty(map[string]string{})) + True(t, isEmpty(new(time.Time))) + True(t, isEmpty(time.Time{})) + True(t, isEmpty(make(chan struct{}))) + True(t, isEmpty([1]int{})) + False(t, isEmpty("something")) + False(t, isEmpty(errors.New("something"))) + False(t, isEmpty([]string{"something"})) + False(t, isEmpty(1)) + False(t, isEmpty(true)) + False(t, isEmpty(map[string]string{"Hello": "World"})) + False(t, isEmpty(chWithValue)) + False(t, isEmpty([1]int{42})) +} + +func Test_getLen(t *testing.T) { + falseCases := []interface{}{ + nil, + 0, + true, + false, + 'A', + struct{}{}, + } + for _, v := range falseCases { + ok, l := getLen(v) + False(t, ok, "Expected getLen fail to get length of %#v", v) + Equal(t, 0, l, "getLen should return 0 for %#v", v) + } + + ch := make(chan int, 5) + ch <- 1 + ch <- 2 + ch <- 3 + trueCases := []struct { + v interface{} + l int + }{ + {[]int{1, 2, 3}, 3}, + {[...]int{1, 2, 3}, 3}, + {"ABC", 3}, + {map[int]int{1: 2, 2: 4, 3: 6}, 3}, + {ch, 3}, + + {[]int{}, 0}, + {map[int]int{}, 0}, + {make(chan int), 0}, + + {[]int(nil), 0}, + {map[int]int(nil), 0}, + {(chan int)(nil), 0}, + } + + for _, c := range trueCases { + ok, l := getLen(c.v) + True(t, ok, "Expected getLen success to get length of %#v", c.v) + Equal(t, c.l, l) + } +} + +func TestLen(t *testing.T) { + mockT := new(testing.T) + + False(t, Len(mockT, nil, 0), "nil does not have length") + False(t, Len(mockT, 0, 0), "int does not have length") + False(t, Len(mockT, true, 0), "true does not have length") + False(t, Len(mockT, false, 0), "false does not have length") + False(t, Len(mockT, 'A', 0), "Rune does not have length") + False(t, Len(mockT, struct{}{}, 0), "Struct does not have length") + + ch := make(chan int, 5) + ch <- 1 + ch <- 2 + ch <- 3 + + cases := []struct { + v interface{} + l int + }{ + {[]int{1, 2, 3}, 3}, + {[...]int{1, 2, 3}, 3}, + {"ABC", 3}, + {map[int]int{1: 2, 2: 4, 3: 6}, 3}, + {ch, 3}, + + {[]int{}, 0}, + {map[int]int{}, 0}, + {make(chan int), 0}, + + {[]int(nil), 0}, + {map[int]int(nil), 0}, + {(chan int)(nil), 0}, + } + + for _, c := range cases { + True(t, Len(mockT, c.v, c.l), "%#v have %d items", c.v, c.l) + } + + cases = []struct { + v interface{} + l int + }{ + {[]int{1, 2, 3}, 4}, + {[...]int{1, 2, 3}, 2}, + {"ABC", 2}, + {map[int]int{1: 2, 2: 4, 3: 6}, 4}, + {ch, 2}, + + {[]int{}, 1}, + {map[int]int{}, 1}, + {make(chan int), 1}, + + {[]int(nil), 1}, + {map[int]int(nil), 1}, + {(chan int)(nil), 1}, + } + + for _, c := range cases { + False(t, Len(mockT, c.v, c.l), "%#v have %d items", c.v, c.l) + } +} + +func TestWithinDuration(t *testing.T) { + + mockT := new(testing.T) + a := time.Now() + b := a.Add(10 * time.Second) + + True(t, WithinDuration(mockT, a, b, 10*time.Second), "A 10s difference is within a 10s time difference") + True(t, WithinDuration(mockT, b, a, 10*time.Second), "A 10s difference is within a 10s time difference") + + False(t, WithinDuration(mockT, a, b, 9*time.Second), "A 10s difference is not within a 9s time difference") + False(t, WithinDuration(mockT, b, a, 9*time.Second), "A 10s difference is not within a 9s time difference") + + False(t, WithinDuration(mockT, a, b, -9*time.Second), "A 10s difference is not within a 9s time difference") + False(t, WithinDuration(mockT, b, a, -9*time.Second), "A 10s difference is not within a 9s time difference") + + False(t, WithinDuration(mockT, a, b, -11*time.Second), "A 10s difference is not within a 9s time difference") + False(t, WithinDuration(mockT, b, a, -11*time.Second), "A 10s difference is not within a 9s time difference") +} + +func TestInDelta(t *testing.T) { + mockT := new(testing.T) + + True(t, InDelta(mockT, 1.001, 1, 0.01), "|1.001 - 1| <= 0.01") + True(t, InDelta(mockT, 1, 1.001, 0.01), "|1 - 1.001| <= 0.01") + True(t, InDelta(mockT, 1, 2, 1), "|1 - 2| <= 1") + False(t, InDelta(mockT, 1, 2, 0.5), "Expected |1 - 2| <= 0.5 to fail") + False(t, InDelta(mockT, 2, 1, 0.5), "Expected |2 - 1| <= 0.5 to fail") + False(t, InDelta(mockT, "", nil, 1), "Expected non numerals to fail") + False(t, InDelta(mockT, 42, math.NaN(), 0.01), "Expected NaN for actual to fail") + False(t, InDelta(mockT, math.NaN(), 42, 0.01), "Expected NaN for expected to fail") + True(t, InDelta(mockT, math.NaN(), math.NaN(), 0.01), "Expected NaN for both to pass") + + cases := []struct { + a, b interface{} + delta float64 + }{ + {uint(2), uint(1), 1}, + {uint8(2), uint8(1), 1}, + {uint16(2), uint16(1), 1}, + {uint32(2), uint32(1), 1}, + {uint64(2), uint64(1), 1}, + + {int(2), int(1), 1}, + {int8(2), int8(1), 1}, + {int16(2), int16(1), 1}, + {int32(2), int32(1), 1}, + {int64(2), int64(1), 1}, + + {float32(2), float32(1), 1}, + {float64(2), float64(1), 1}, + } + + for _, tc := range cases { + True(t, InDelta(mockT, tc.a, tc.b, tc.delta), "Expected |%V - %V| <= %v", tc.a, tc.b, tc.delta) + } +} + +type diffTestingStruct struct { + A string + B int +} + +func (d *diffTestingStruct) String() string { + return d.A +} + +func TestDiff(t *testing.T) { + expected := ` + +Diff: +--- Expected ++++ Actual +@@ -1,3 +1,3 @@ + (struct { foo string }) { +- foo: (string) (len=5) "hello" ++ foo: (string) (len=3) "bar" + } +` + actual := diff( + struct{ foo string }{"hello"}, + struct{ foo string }{"bar"}, + ) + Equal(t, expected, actual) + + expected = ` + +Diff: +--- Expected ++++ Actual +@@ -2,5 +2,5 @@ + (int) 1, +- (int) 2, + (int) 3, +- (int) 4 ++ (int) 5, ++ (int) 7 + } +` + actual = diff( + []int{1, 2, 3, 4}, + []int{1, 3, 5, 7}, + ) + Equal(t, expected, actual) + + expected = ` + +Diff: +--- Expected ++++ Actual +@@ -2,4 +2,4 @@ + (int) 1, +- (int) 2, +- (int) 3 ++ (int) 3, ++ (int) 5 + } +` + actual = diff( + []int{1, 2, 3, 4}[0:3], + []int{1, 3, 5, 7}[0:3], + ) + Equal(t, expected, actual) + + expected = ` + +Diff: +--- Expected ++++ Actual +@@ -1,6 +1,6 @@ + (map[string]int) (len=4) { +- (string) (len=4) "four": (int) 4, ++ (string) (len=4) "five": (int) 5, + (string) (len=3) "one": (int) 1, +- (string) (len=5) "three": (int) 3, +- (string) (len=3) "two": (int) 2 ++ (string) (len=5) "seven": (int) 7, ++ (string) (len=5) "three": (int) 3 + } +` + + actual = diff( + map[string]int{"one": 1, "two": 2, "three": 3, "four": 4}, + map[string]int{"one": 1, "three": 3, "five": 5, "seven": 7}, + ) + Equal(t, expected, actual) + + expected = ` + +Diff: +--- Expected ++++ Actual +@@ -1,3 +1,3 @@ + (*errors.errorString)({ +- s: (string) (len=19) "some expected error" ++ s: (string) (len=12) "actual error" + }) +` + + actual = diff( + errors.New("some expected error"), + errors.New("actual error"), + ) + Equal(t, expected, actual) + + expected = ` + +Diff: +--- Expected ++++ Actual +@@ -2,3 +2,3 @@ + A: (string) (len=11) "some string", +- B: (int) 10 ++ B: (int) 15 + } +` + + actual = diff( + diffTestingStruct{A: "some string", B: 10}, + diffTestingStruct{A: "some string", B: 15}, + ) + Equal(t, expected, actual) + + expected = ` + +Diff: +--- Expected ++++ Actual +@@ -1,2 +1,2 @@ +-(time.Time) 2020-09-24 00:00:00 +0000 UTC ++(time.Time) 2020-09-25 00:00:00 +0000 UTC + +` + + actual = diff( + time.Date(2020, 9, 24, 0, 0, 0, 0, time.UTC), + time.Date(2020, 9, 25, 0, 0, 0, 0, time.UTC), + ) + Equal(t, expected, actual) +} + +func TestTimeEqualityErrorFormatting(_ *testing.T) { + mockT := new(mockTestingT) + + Equal(mockT, time.Second*2, time.Millisecond) +} + +func TestDiffEmptyCases(t *testing.T) { + Equal(t, "", diff(nil, nil)) + Equal(t, "", diff(struct{ foo string }{}, nil)) + Equal(t, "", diff(nil, struct{ foo string }{})) + Equal(t, "", diff(1, 2)) + Equal(t, "", diff(1, 2)) + Equal(t, "", diff([]int{1}, []bool{true})) +} + +// Ensure there are no data races +func TestDiffRace(t *testing.T) { + t.Parallel() + + expected := map[string]string{ + "a": "A", + "b": "B", + "c": "C", + } + + actual := map[string]string{ + "d": "D", + "e": "E", + "f": "F", + } + + // run diffs in parallel simulating tests with t.Parallel() + numRoutines := 10 + rChans := make([]chan string, numRoutines) + for idx := range rChans { + rChans[idx] = make(chan string) + go func(ch chan string) { + defer close(ch) + ch <- diff(expected, actual) + }(rChans[idx]) + } + + for _, ch := range rChans { + for msg := range ch { + NotEqual(t, msg, "") // dummy assert + } + } +} + +type mockTestingT struct { + errorFmt string + args []interface{} +} + +func (m *mockTestingT) errorString() string { + return fmt.Sprintf(m.errorFmt, m.args...) +} + +func (m *mockTestingT) Errorf(format string, args ...interface{}) { + m.errorFmt = format + m.args = args +} + +type mockFailNowTestingT struct { +} + +func (m *mockFailNowTestingT) Errorf(string, ...interface{}) {} + +func (m *mockFailNowTestingT) FailNow() {} + +func TestBytesEqual(t *testing.T) { + var cases = []struct { + a, b []byte + }{ + {make([]byte, 2), make([]byte, 2)}, + {make([]byte, 2), make([]byte, 2, 3)}, + {nil, make([]byte, 0)}, + } + for i, c := range cases { + Equal(t, reflect.DeepEqual(c.a, c.b), ObjectsAreEqual(c.a, c.b), "case %d failed", i+1) + } +} + +func BenchmarkBytesEqual(b *testing.B) { + const size = 1024 * 8 + s := make([]byte, size) + for i := range s { + s[i] = byte(i % 255) + } + s2 := make([]byte, size) + copy(s2, s) + + mockT := &mockFailNowTestingT{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + Equal(mockT, s, s2) + } +} + +func BenchmarkNotNil(b *testing.B) { + for i := 0; i < b.N; i++ { + NotNil(b, b) + } +} + +func TestEventuallyFalse(t *testing.T) { + mockT := new(testing.T) + + condition := func() bool { + return false + } + + False(t, Eventually(mockT, condition, 100*time.Millisecond, 20*time.Millisecond)) +} + +func TestEventuallyTrue(t *testing.T) { + state := 0 + condition := func() bool { + defer func() { + state++ + }() + return state == 2 + } + + True(t, Eventually(t, condition, 100*time.Millisecond, 20*time.Millisecond)) +} + +func Test_validateEqualArgs(t *testing.T) { + if validateEqualArgs(func() {}, func() {}) == nil { + t.Error("non-nil functions should error") + } + + if validateEqualArgs(func() {}, func() {}) == nil { + t.Error("non-nil functions should error") + } + + if validateEqualArgs(nil, nil) != nil { + t.Error("nil functions are equal") + } +} + +func Test_truncatingFormat(t *testing.T) { + + original := strings.Repeat("a", bufio.MaxScanTokenSize-102) + result := truncatingFormat(original) + Equal(t, fmt.Sprintf("%#v", original), result, "string should not be truncated") + + original = original + "x" + result = truncatingFormat(original) + NotEqual(t, fmt.Sprintf("%#v", original), result, "string should have been truncated.") + + if !strings.HasSuffix(result, "<... truncated>") { + t.Error("truncated string should have <... truncated> suffix") + } +} diff --git a/drivers/mongov2/internal/assert/difflib.go b/drivers/mongov2/internal/assert/difflib.go new file mode 100644 index 0000000..e13a66a --- /dev/null +++ b/drivers/mongov2/internal/assert/difflib.go @@ -0,0 +1,766 @@ +// Copied from https://github.com/pmezard/go-difflib/blob/5d4384ee4fb2527b0a1256a821ebfc92f91efefc/difflib/difflib.go + +// Copyright 2013 Patrick Mezard. All rights reserved. Use of this source code is +// governed by a license that can be found in the THIRD-PARTY-NOTICES file. + +package assert + +import ( + "bufio" + "bytes" + "fmt" + "io" + "strings" +) + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +func calculateRatio(matches, length int) float64 { + if length > 0 { + return 2.0 * float64(matches) / float64(length) + } + return 1.0 +} + +type Match struct { + A int + B int + Size int +} + +type OpCode struct { + Tag byte + I1 int + I2 int + J1 int + J2 int +} + +// SequenceMatcher compares sequence of strings. The basic +// algorithm predates, and is a little fancier than, an algorithm +// published in the late 1980's by Ratcliff and Obershelp under the +// hyperbolic name "gestalt pattern matching". The basic idea is to find +// the longest contiguous matching subsequence that contains no "junk" +// elements (R-O doesn't address junk). The same idea is then applied +// recursively to the pieces of the sequences to the left and to the right +// of the matching subsequence. This does not yield minimal edit +// sequences, but does tend to yield matches that "look right" to people. +// +// SequenceMatcher tries to compute a "human-friendly diff" between two +// sequences. Unlike e.g. UNIX(tm) diff, the fundamental notion is the +// longest *contiguous* & junk-free matching subsequence. That's what +// catches peoples' eyes. The Windows(tm) windiff has another interesting +// notion, pairing up elements that appear uniquely in each sequence. +// That, and the method here, appear to yield more intuitive difference +// reports than does diff. This method appears to be the least vulnerable +// to syncing up on blocks of "junk lines", though (like blank lines in +// ordinary text files, or maybe "

" lines in HTML files). That may be +// because this is the only method of the 3 that has a *concept* of +// "junk" . +// +// Timing: Basic R-O is cubic time worst case and quadratic time expected +// case. SequenceMatcher is quadratic time for the worst case and has +// expected-case behavior dependent in a complicated way on how many +// elements the sequences have in common; best case time is linear. +type SequenceMatcher struct { + a []string + b []string + b2j map[string][]int + IsJunk func(string) bool + autoJunk bool + bJunk map[string]struct{} + matchingBlocks []Match + fullBCount map[string]int + bPopular map[string]struct{} + opCodes []OpCode +} + +func NewMatcher(a, b []string) *SequenceMatcher { + m := SequenceMatcher{autoJunk: true} + m.SetSeqs(a, b) + return &m +} + +func NewMatcherWithJunk(a, b []string, autoJunk bool, + isJunk func(string) bool) *SequenceMatcher { + + m := SequenceMatcher{IsJunk: isJunk, autoJunk: autoJunk} + m.SetSeqs(a, b) + return &m +} + +// SetSeqs sets the two sequences to be compared. +func (m *SequenceMatcher) SetSeqs(a, b []string) { + m.SetSeq1(a) + m.SetSeq2(b) +} + +// SetSeq1 sets the first sequence to be compared. The second sequence to be compared is +// not changed. +// +// SequenceMatcher computes and caches detailed information about the second +// sequence, so if you want to compare one sequence S against many sequences, +// use .SetSeq2(s) once and call .SetSeq1(x) repeatedly for each of the other +// sequences. +// +// See also SetSeqs() and SetSeq2(). +func (m *SequenceMatcher) SetSeq1(a []string) { + if &a == &m.a { + return + } + m.a = a + m.matchingBlocks = nil + m.opCodes = nil +} + +// SetSeq2 sets the second sequence to be compared. The first sequence to be compared is +// not changed. +func (m *SequenceMatcher) SetSeq2(b []string) { + if &b == &m.b { + return + } + m.b = b + m.matchingBlocks = nil + m.opCodes = nil + m.fullBCount = nil + m.chainB() +} + +func (m *SequenceMatcher) chainB() { + // Populate line -> index mapping + b2j := map[string][]int{} + for i, s := range m.b { + indices := b2j[s] + indices = append(indices, i) + b2j[s] = indices + } + + // Purge junk elements + m.bJunk = map[string]struct{}{} + if m.IsJunk != nil { + junk := m.bJunk + for s := range b2j { + if m.IsJunk(s) { + junk[s] = struct{}{} + } + } + for s := range junk { + delete(b2j, s) + } + } + + // Purge remaining popular elements + popular := map[string]struct{}{} + n := len(m.b) + if m.autoJunk && n >= 200 { + ntest := n/100 + 1 + for s, indices := range b2j { + if len(indices) > ntest { + popular[s] = struct{}{} + } + } + for s := range popular { + delete(b2j, s) + } + } + m.bPopular = popular + m.b2j = b2j +} + +func (m *SequenceMatcher) isBJunk(s string) bool { + _, ok := m.bJunk[s] + return ok +} + +// Find longest matching block in a[alo:ahi] and b[blo:bhi]. +// +// If IsJunk is not defined: +// +// Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where +// +// alo <= i <= i+k <= ahi +// blo <= j <= j+k <= bhi +// +// and for all (i',j',k') meeting those conditions, +// +// k >= k' +// i <= i' +// and if i == i', j <= j' +// +// In other words, of all maximal matching blocks, return one that +// starts earliest in a, and of all those maximal matching blocks that +// start earliest in a, return the one that starts earliest in b. +// +// If IsJunk is defined, first the longest matching block is +// determined as above, but with the additional restriction that no +// junk element appears in the block. Then that block is extended as +// far as possible by matching (only) junk elements on both sides. So +// the resulting block never matches on junk except as identical junk +// happens to be adjacent to an "interesting" match. +// +// If no blocks match, return (alo, blo, 0). +func (m *SequenceMatcher) findLongestMatch(alo, ahi, blo, bhi int) Match { + // CAUTION: stripping common prefix or suffix would be incorrect. + // E.g., + // ab + // acab + // Longest matching block is "ab", but if common prefix is + // stripped, it's "a" (tied with "b"). UNIX(tm) diff does so + // strip, so ends up claiming that ab is changed to acab by + // inserting "ca" in the middle. That's minimal but unintuitive: + // "it's obvious" that someone inserted "ac" at the front. + // Windiff ends up at the same place as diff, but by pairing up + // the unique 'b's and then matching the first two 'a's. + besti, bestj, bestsize := alo, blo, 0 + + // find longest junk-free match + // during an iteration of the loop, j2len[j] = length of longest + // junk-free match ending with a[i-1] and b[j] + j2len := map[int]int{} + for i := alo; i != ahi; i++ { + // look at all instances of a[i] in b; note that because + // b2j has no junk keys, the loop is skipped if a[i] is junk + newj2len := map[int]int{} + for _, j := range m.b2j[m.a[i]] { + // a[i] matches b[j] + if j < blo { + continue + } + if j >= bhi { + break + } + k := j2len[j-1] + 1 + newj2len[j] = k + if k > bestsize { + besti, bestj, bestsize = i-k+1, j-k+1, k + } + } + j2len = newj2len + } + + // Extend the best by non-junk elements on each end. In particular, + // "popular" non-junk elements aren't in b2j, which greatly speeds + // the inner loop above, but also means "the best" match so far + // doesn't contain any junk *or* popular non-junk elements. + for besti > alo && bestj > blo && !m.isBJunk(m.b[bestj-1]) && + m.a[besti-1] == m.b[bestj-1] { + besti, bestj, bestsize = besti-1, bestj-1, bestsize+1 + } + for besti+bestsize < ahi && bestj+bestsize < bhi && + !m.isBJunk(m.b[bestj+bestsize]) && + m.a[besti+bestsize] == m.b[bestj+bestsize] { + bestsize++ + } + + // Now that we have a wholly interesting match (albeit possibly + // empty!), we may as well suck up the matching junk on each + // side of it too. Can't think of a good reason not to, and it + // saves post-processing the (possibly considerable) expense of + // figuring out what to do with it. In the case of an empty + // interesting match, this is clearly the right thing to do, + // because no other kind of match is possible in the regions. + for besti > alo && bestj > blo && m.isBJunk(m.b[bestj-1]) && + m.a[besti-1] == m.b[bestj-1] { + besti, bestj, bestsize = besti-1, bestj-1, bestsize+1 + } + for besti+bestsize < ahi && bestj+bestsize < bhi && + m.isBJunk(m.b[bestj+bestsize]) && + m.a[besti+bestsize] == m.b[bestj+bestsize] { + bestsize++ + } + + return Match{A: besti, B: bestj, Size: bestsize} +} + +// GetMatchingBlocks returns list of triples describing matching subsequences. +// +// Each triple is of the form (i, j, n), and means that +// a[i:i+n] == b[j:j+n]. The triples are monotonically increasing in +// i and in j. It's also guaranteed that if (i, j, n) and (i', j', n') are +// adjacent triples in the list, and the second is not the last triple in the +// list, then i+n != i' or j+n != j'. IOW, adjacent triples never describe +// adjacent equal blocks. +// +// The last triple is a dummy, (len(a), len(b), 0), and is the only +// triple with n==0. +func (m *SequenceMatcher) GetMatchingBlocks() []Match { + if m.matchingBlocks != nil { + return m.matchingBlocks + } + + var matchBlocks func(alo, ahi, blo, bhi int, matched []Match) []Match + matchBlocks = func(alo, ahi, blo, bhi int, matched []Match) []Match { + match := m.findLongestMatch(alo, ahi, blo, bhi) + i, j, k := match.A, match.B, match.Size + if match.Size > 0 { + if alo < i && blo < j { + matched = matchBlocks(alo, i, blo, j, matched) + } + matched = append(matched, match) + if i+k < ahi && j+k < bhi { + matched = matchBlocks(i+k, ahi, j+k, bhi, matched) + } + } + return matched + } + matched := matchBlocks(0, len(m.a), 0, len(m.b), nil) + + // It's possible that we have adjacent equal blocks in the + // matching_blocks list now. + nonAdjacent := []Match{} + i1, j1, k1 := 0, 0, 0 + for _, b := range matched { + // Is this block adjacent to i1, j1, k1? + i2, j2, k2 := b.A, b.B, b.Size + if i1+k1 == i2 && j1+k1 == j2 { + // Yes, so collapse them -- this just increases the length of + // the first block by the length of the second, and the first + // block so lengthened remains the block to compare against. + k1 += k2 + } else { + // Not adjacent. Remember the first block (k1==0 means it's + // the dummy we started with), and make the second block the + // new block to compare against. + if k1 > 0 { + nonAdjacent = append(nonAdjacent, Match{i1, j1, k1}) + } + i1, j1, k1 = i2, j2, k2 + } + } + if k1 > 0 { + nonAdjacent = append(nonAdjacent, Match{i1, j1, k1}) + } + + nonAdjacent = append(nonAdjacent, Match{len(m.a), len(m.b), 0}) + m.matchingBlocks = nonAdjacent + return m.matchingBlocks +} + +// GetOpCodes returns a list of 5-tuples describing how to turn a into b. +// +// Each tuple is of the form (tag, i1, i2, j1, j2). The first tuple +// has i1 == j1 == 0, and remaining tuples have i1 == the i2 from the +// tuple preceding it, and likewise for j1 == the previous j2. +// +// The tags are characters, with these meanings: +// +// 'r' (replace): a[i1:i2] should be replaced by b[j1:j2] +// +// 'd' (delete): a[i1:i2] should be deleted, j1==j2 in this case. +// +// 'i' (insert): b[j1:j2] should be inserted at a[i1:i1], i1==i2 in this case. +// +// 'e' (equal): a[i1:i2] == b[j1:j2] +func (m *SequenceMatcher) GetOpCodes() []OpCode { + if m.opCodes != nil { + return m.opCodes + } + i, j := 0, 0 + matching := m.GetMatchingBlocks() + opCodes := make([]OpCode, 0, len(matching)) + for _, m := range matching { + // invariant: we've pumped out correct diffs to change + // a[:i] into b[:j], and the next matching block is + // a[ai:ai+size] == b[bj:bj+size]. So we need to pump + // out a diff to change a[i:ai] into b[j:bj], pump out + // the matching block, and move (i,j) beyond the match + ai, bj, size := m.A, m.B, m.Size + tag := byte(0) + if i < ai && j < bj { + tag = 'r' + } else if i < ai { + tag = 'd' + } else if j < bj { + tag = 'i' + } + if tag > 0 { + opCodes = append(opCodes, OpCode{tag, i, ai, j, bj}) + } + i, j = ai+size, bj+size + // the list of matching blocks is terminated by a + // sentinel with size 0 + if size > 0 { + opCodes = append(opCodes, OpCode{'e', ai, i, bj, j}) + } + } + m.opCodes = opCodes + return m.opCodes +} + +// GetGroupedOpCodes isolates change clusters by eliminating ranges with no changes. +// +// Returns a generator of groups with up to n lines of context. +// Each group is in the same format as returned by GetOpCodes(). +func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode { + if n < 0 { + n = 3 + } + codes := m.GetOpCodes() + if len(codes) == 0 { + codes = []OpCode{{'e', 0, 1, 0, 1}} + } + // Fixup leading and trailing groups if they show no changes. + if codes[0].Tag == 'e' { + c := codes[0] + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + codes[0] = OpCode{c.Tag, max(i1, i2-n), i2, max(j1, j2-n), j2} + } + if codes[len(codes)-1].Tag == 'e' { + c := codes[len(codes)-1] + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + codes[len(codes)-1] = OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)} + } + nn := n + n + groups := [][]OpCode{} + group := []OpCode{} + for _, c := range codes { + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + // End the current group and start a new one whenever + // there is a large range with no changes. + if c.Tag == 'e' && i2-i1 > nn { + group = append(group, OpCode{c.Tag, i1, min(i2, i1+n), + j1, min(j2, j1+n)}) + groups = append(groups, group) + group = []OpCode{} + i1, j1 = max(i1, i2-n), max(j1, j2-n) + } + group = append(group, OpCode{c.Tag, i1, i2, j1, j2}) + } + if len(group) > 0 && !(len(group) == 1 && group[0].Tag == 'e') { + groups = append(groups, group) + } + return groups +} + +// Ratio returns a measure of the sequences' similarity (float in [0,1]). +// +// Where T is the total number of elements in both sequences, and +// M is the number of matches, this is 2.0*M / T. +// Note that this is 1 if the sequences are identical, and 0 if +// they have nothing in common. +// +// .Ratio() is expensive to compute if you haven't already computed +// .GetMatchingBlocks() or .GetOpCodes(), in which case you may +// want to try .QuickRatio() or .RealQuickRation() first to get an +// upper bound. +func (m *SequenceMatcher) Ratio() float64 { + matches := 0 + for _, m := range m.GetMatchingBlocks() { + matches += m.Size + } + return calculateRatio(matches, len(m.a)+len(m.b)) +} + +// QuickRatio returns an upper bound on ratio() relatively quickly. +// +// This isn't defined beyond that it is an upper bound on .Ratio(), and +// is faster to compute. +func (m *SequenceMatcher) QuickRatio() float64 { + // viewing a and b as multisets, set matches to the cardinality + // of their intersection; this counts the number of matches + // without regard to order, so is clearly an upper bound + if m.fullBCount == nil { + m.fullBCount = map[string]int{} + for _, s := range m.b { + m.fullBCount[s] = m.fullBCount[s] + 1 + } + } + + // avail[x] is the number of times x appears in 'b' less the + // number of times we've seen it in 'a' so far ... kinda + avail := map[string]int{} + matches := 0 + for _, s := range m.a { + n, ok := avail[s] + if !ok { + n = m.fullBCount[s] + } + avail[s] = n - 1 + if n > 0 { + matches++ + } + } + return calculateRatio(matches, len(m.a)+len(m.b)) +} + +// RealQuickRatio returns an upper bound on ratio() very quickly. +// +// This isn't defined beyond that it is an upper bound on .Ratio(), and +// is faster to compute than either .Ratio() or .QuickRatio(). +func (m *SequenceMatcher) RealQuickRatio() float64 { + la, lb := len(m.a), len(m.b) + return calculateRatio(min(la, lb), la+lb) +} + +// Convert range to the "ed" format +func formatRangeUnified(start, stop int) string { + // Per the diff spec at http://www.unix.org/single_unix_specification/ + beginning := start + 1 // lines start numbering with one + length := stop - start + if length == 1 { + return fmt.Sprintf("%d", beginning) + } + if length == 0 { + beginning-- // empty ranges begin at line just before the range + } + return fmt.Sprintf("%d,%d", beginning, length) +} + +// UnifiedDiff represents the unified diff parameters. +type UnifiedDiff struct { + A []string // First sequence lines + FromFile string // First file name + FromDate string // First file time + B []string // Second sequence lines + ToFile string // Second file name + ToDate string // Second file time + Eol string // Headers end of line, defaults to LF + Context int // Number of context lines +} + +// WriteUnifiedDiff compares two sequences of lines; generates the delta as +// a unified diff. +// +// Unified diffs are a compact way of showing line changes and a few +// lines of context. The number of context lines is set by 'n' which +// defaults to three. +// +// By default, the diff control lines (those with ---, +++, or @@) are +// created with a trailing newline. This is helpful so that inputs +// created from file.readlines() result in diffs that are suitable for +// file.writelines() since both the inputs and outputs have trailing +// newlines. +// +// For inputs that do not have trailing newlines, set the lineterm +// argument to "" so that the output will be uniformly newline free. +// +// The unidiff format normally has a header for filenames and modification +// times. Any or all of these may be specified using strings for +// 'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'. +// The modification times are normally expressed in the ISO 8601 format. +func WriteUnifiedDiff(writer io.Writer, diff UnifiedDiff) error { + buf := bufio.NewWriter(writer) + defer buf.Flush() + wf := func(format string, args ...interface{}) error { + _, err := buf.WriteString(fmt.Sprintf(format, args...)) + return err + } + ws := func(s string) error { + _, err := buf.WriteString(s) + return err + } + + if len(diff.Eol) == 0 { + diff.Eol = "\n" + } + + started := false + m := NewMatcher(diff.A, diff.B) + for _, g := range m.GetGroupedOpCodes(diff.Context) { + if !started { + started = true + fromDate := "" + if len(diff.FromDate) > 0 { + fromDate = "\t" + diff.FromDate + } + toDate := "" + if len(diff.ToDate) > 0 { + toDate = "\t" + diff.ToDate + } + if diff.FromFile != "" || diff.ToFile != "" { + err := wf("--- %s%s%s", diff.FromFile, fromDate, diff.Eol) + if err != nil { + return err + } + err = wf("+++ %s%s%s", diff.ToFile, toDate, diff.Eol) + if err != nil { + return err + } + } + } + first, last := g[0], g[len(g)-1] + range1 := formatRangeUnified(first.I1, last.I2) + range2 := formatRangeUnified(first.J1, last.J2) + if err := wf("@@ -%s +%s @@%s", range1, range2, diff.Eol); err != nil { + return err + } + for _, c := range g { + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + if c.Tag == 'e' { + for _, line := range diff.A[i1:i2] { + if err := ws(" " + line); err != nil { + return err + } + } + continue + } + if c.Tag == 'r' || c.Tag == 'd' { + for _, line := range diff.A[i1:i2] { + if err := ws("-" + line); err != nil { + return err + } + } + } + if c.Tag == 'r' || c.Tag == 'i' { + for _, line := range diff.B[j1:j2] { + if err := ws("+" + line); err != nil { + return err + } + } + } + } + } + return nil +} + +// GetUnifiedDiffString is like WriteUnifiedDiff but returns the diff as a string. +func GetUnifiedDiffString(diff UnifiedDiff) (string, error) { + w := &bytes.Buffer{} + err := WriteUnifiedDiff(w, diff) + return w.String(), err +} + +// Convert range to the "ed" format. +func formatRangeContext(start, stop int) string { + // Per the diff spec at http://www.unix.org/single_unix_specification/ + beginning := start + 1 // lines start numbering with one + length := stop - start + if length == 0 { + beginning-- // empty ranges begin at line just before the range + } + if length <= 1 { + return fmt.Sprintf("%d", beginning) + } + return fmt.Sprintf("%d,%d", beginning, beginning+length-1) +} + +type ContextDiff UnifiedDiff + +// WriteContextDiff compares two sequences of lines; generates the delta as a context diff. +// +// Context diffs are a compact way of showing line changes and a few +// lines of context. The number of context lines is set by diff.Context +// which defaults to three. +// +// By default, the diff control lines (those with *** or ---) are +// created with a trailing newline. +// +// For inputs that do not have trailing newlines, set the diff.Eol +// argument to "" so that the output will be uniformly newline free. +// +// The context diff format normally has a header for filenames and +// modification times. Any or all of these may be specified using +// strings for diff.FromFile, diff.ToFile, diff.FromDate, diff.ToDate. +// The modification times are normally expressed in the ISO 8601 format. +// If not specified, the strings default to blanks. +func WriteContextDiff(writer io.Writer, diff ContextDiff) error { + buf := bufio.NewWriter(writer) + defer buf.Flush() + var diffErr error + wf := func(format string, args ...interface{}) { + _, err := buf.WriteString(fmt.Sprintf(format, args...)) + if diffErr == nil && err != nil { + diffErr = err + } + } + ws := func(s string) { + _, err := buf.WriteString(s) + if diffErr == nil && err != nil { + diffErr = err + } + } + + if len(diff.Eol) == 0 { + diff.Eol = "\n" + } + + prefix := map[byte]string{ + 'i': "+ ", + 'd': "- ", + 'r': "! ", + 'e': " ", + } + + started := false + m := NewMatcher(diff.A, diff.B) + for _, g := range m.GetGroupedOpCodes(diff.Context) { + if !started { + started = true + fromDate := "" + if len(diff.FromDate) > 0 { + fromDate = "\t" + diff.FromDate + } + toDate := "" + if len(diff.ToDate) > 0 { + toDate = "\t" + diff.ToDate + } + if diff.FromFile != "" || diff.ToFile != "" { + wf("*** %s%s%s", diff.FromFile, fromDate, diff.Eol) + wf("--- %s%s%s", diff.ToFile, toDate, diff.Eol) + } + } + + first, last := g[0], g[len(g)-1] + ws("***************" + diff.Eol) + + range1 := formatRangeContext(first.I1, last.I2) + wf("*** %s ****%s", range1, diff.Eol) + for _, c := range g { + if c.Tag == 'r' || c.Tag == 'd' { + for _, cc := range g { + if cc.Tag == 'i' { + continue + } + for _, line := range diff.A[cc.I1:cc.I2] { + ws(prefix[cc.Tag] + line) + } + } + break + } + } + + range2 := formatRangeContext(first.J1, last.J2) + wf("--- %s ----%s", range2, diff.Eol) + for _, c := range g { + if c.Tag == 'r' || c.Tag == 'i' { + for _, cc := range g { + if cc.Tag == 'd' { + continue + } + for _, line := range diff.B[cc.J1:cc.J2] { + ws(prefix[cc.Tag] + line) + } + } + break + } + } + } + return diffErr +} + +// GetContextDiffString is like WriteContextDiff but returns the diff as a string. +func GetContextDiffString(diff ContextDiff) (string, error) { + w := &bytes.Buffer{} + err := WriteContextDiff(w, diff) + return w.String(), err +} + +// SplitLines splits a string on "\n" while preserving them. The output can be used +// as input for UnifiedDiff and ContextDiff structures. +func SplitLines(s string) []string { + lines := strings.SplitAfter(s, "\n") + lines[len(lines)-1] += "\n" + return lines +} diff --git a/drivers/mongov2/internal/assert/difflib_test.go b/drivers/mongov2/internal/assert/difflib_test.go new file mode 100644 index 0000000..b310a92 --- /dev/null +++ b/drivers/mongov2/internal/assert/difflib_test.go @@ -0,0 +1,326 @@ +// Copied from https://github.com/pmezard/go-difflib/blob/5d4384ee4fb2527b0a1256a821ebfc92f91efefc/difflib/difflib_test.go + +// Copyright 2013 Patrick Mezard. All rights reserved. Use of this source code is +// governed by a license that can be found in the THIRD-PARTY-NOTICES file. + +package assert + +import ( + "bytes" + "fmt" + "math" + "reflect" + "strings" + "testing" +) + +func assertAlmostEqual(t *testing.T, a, b float64, places int) { + if math.Abs(a-b) > math.Pow10(-places) { + t.Errorf("%.7f != %.7f", a, b) + } +} + +func assertEqual(t *testing.T, a, b interface{}) { + if !reflect.DeepEqual(a, b) { + t.Errorf("%v != %v", a, b) + } +} + +func splitChars(s string) []string { + chars := make([]string, 0, len(s)) + // Assume ASCII inputs + for i := 0; i != len(s); i++ { + chars = append(chars, string(s[i])) + } + return chars +} + +func TestSequenceMatcherRatio(t *testing.T) { + s := NewMatcher(splitChars("abcd"), splitChars("bcde")) + assertEqual(t, s.Ratio(), 0.75) + assertEqual(t, s.QuickRatio(), 0.75) + assertEqual(t, s.RealQuickRatio(), 1.0) +} + +func TestGetOptCodes(t *testing.T) { + a := "qabxcd" + b := "abycdf" + s := NewMatcher(splitChars(a), splitChars(b)) + w := &bytes.Buffer{} + for _, op := range s.GetOpCodes() { + fmt.Fprintf(w, "%s a[%d:%d], (%s) b[%d:%d] (%s)\n", string(op.Tag), + op.I1, op.I2, a[op.I1:op.I2], op.J1, op.J2, b[op.J1:op.J2]) + } + result := w.String() + expected := `d a[0:1], (q) b[0:0] () +e a[1:3], (ab) b[0:2] (ab) +r a[3:4], (x) b[2:3] (y) +e a[4:6], (cd) b[3:5] (cd) +i a[6:6], () b[5:6] (f) +` + if expected != result { + t.Errorf("unexpected op codes: \n%s", result) + } +} + +func TestGroupedOpCodes(t *testing.T) { + a := []string{} + for i := 0; i != 39; i++ { + a = append(a, fmt.Sprintf("%02d", i)) + } + b := []string{} + b = append(b, a[:8]...) + b = append(b, " i") + b = append(b, a[8:19]...) + b = append(b, " x") + b = append(b, a[20:22]...) + b = append(b, a[27:34]...) + b = append(b, " y") + b = append(b, a[35:]...) + s := NewMatcher(a, b) + w := &bytes.Buffer{} + for _, g := range s.GetGroupedOpCodes(-1) { + fmt.Fprintf(w, "group\n") + for _, op := range g { + fmt.Fprintf(w, " %s, %d, %d, %d, %d\n", string(op.Tag), + op.I1, op.I2, op.J1, op.J2) + } + } + result := w.String() + expected := `group + e, 5, 8, 5, 8 + i, 8, 8, 8, 9 + e, 8, 11, 9, 12 +group + e, 16, 19, 17, 20 + r, 19, 20, 20, 21 + e, 20, 22, 21, 23 + d, 22, 27, 23, 23 + e, 27, 30, 23, 26 +group + e, 31, 34, 27, 30 + r, 34, 35, 30, 31 + e, 35, 38, 31, 34 +` + if expected != result { + t.Errorf("unexpected op codes: \n%s", result) + } +} + +func rep(s string, count int) string { + return strings.Repeat(s, count) +} + +func TestWithAsciiOneInsert(t *testing.T) { + sm := NewMatcher(splitChars(rep("b", 100)), + splitChars("a"+rep("b", 100))) + assertAlmostEqual(t, sm.Ratio(), 0.995, 3) + assertEqual(t, sm.GetOpCodes(), + []OpCode{{'i', 0, 0, 0, 1}, {'e', 0, 100, 1, 101}}) + assertEqual(t, len(sm.bPopular), 0) + + sm = NewMatcher(splitChars(rep("b", 100)), + splitChars(rep("b", 50)+"a"+rep("b", 50))) + assertAlmostEqual(t, sm.Ratio(), 0.995, 3) + assertEqual(t, sm.GetOpCodes(), + []OpCode{{'e', 0, 50, 0, 50}, {'i', 50, 50, 50, 51}, {'e', 50, 100, 51, 101}}) + assertEqual(t, len(sm.bPopular), 0) +} + +func TestWithAsciiOnDelete(t *testing.T) { + sm := NewMatcher(splitChars(rep("a", 40)+"c"+rep("b", 40)), + splitChars(rep("a", 40)+rep("b", 40))) + assertAlmostEqual(t, sm.Ratio(), 0.994, 3) + assertEqual(t, sm.GetOpCodes(), + []OpCode{{'e', 0, 40, 0, 40}, {'d', 40, 41, 40, 40}, {'e', 41, 81, 40, 80}}) +} + +func TestWithAsciiBJunk(t *testing.T) { + isJunk := func(s string) bool { + return s == " " + } + sm := NewMatcherWithJunk(splitChars(rep("a", 40)+rep("b", 40)), + splitChars(rep("a", 44)+rep("b", 40)), true, isJunk) + assertEqual(t, sm.bJunk, map[string]struct{}{}) + + sm = NewMatcherWithJunk(splitChars(rep("a", 40)+rep("b", 40)), + splitChars(rep("a", 44)+rep("b", 40)+rep(" ", 20)), false, isJunk) + assertEqual(t, sm.bJunk, map[string]struct{}{" ": struct{}{}}) + + isJunk = func(s string) bool { + return s == " " || s == "b" + } + sm = NewMatcherWithJunk(splitChars(rep("a", 40)+rep("b", 40)), + splitChars(rep("a", 44)+rep("b", 40)+rep(" ", 20)), false, isJunk) + assertEqual(t, sm.bJunk, map[string]struct{}{" ": struct{}{}, "b": struct{}{}}) +} + +func TestSFBugsRatioForNullSeqn(t *testing.T) { + sm := NewMatcher(nil, nil) + assertEqual(t, sm.Ratio(), 1.0) + assertEqual(t, sm.QuickRatio(), 1.0) + assertEqual(t, sm.RealQuickRatio(), 1.0) +} + +func TestSFBugsComparingEmptyLists(t *testing.T) { + groups := NewMatcher(nil, nil).GetGroupedOpCodes(-1) + assertEqual(t, len(groups), 0) + diff := UnifiedDiff{ + FromFile: "Original", + ToFile: "Current", + Context: 3, + } + result, err := GetUnifiedDiffString(diff) + assertEqual(t, err, nil) + assertEqual(t, result, "") +} + +func TestOutputFormatRangeFormatUnified(t *testing.T) { + // Per the diff spec at http://www.unix.org/single_unix_specification/ + // + // Each field shall be of the form: + // %1d", if the range contains exactly one line, + // and: + // "%1d,%1d", , otherwise. + // If a range is empty, its beginning line number shall be the number of + // the line just before the range, or 0 if the empty range starts the file. + fm := formatRangeUnified + assertEqual(t, fm(3, 3), "3,0") + assertEqual(t, fm(3, 4), "4") + assertEqual(t, fm(3, 5), "4,2") + assertEqual(t, fm(3, 6), "4,3") + assertEqual(t, fm(0, 0), "0,0") +} + +func TestOutputFormatRangeFormatContext(t *testing.T) { + // Per the diff spec at http://www.unix.org/single_unix_specification/ + // + // The range of lines in file1 shall be written in the following format + // if the range contains two or more lines: + // "*** %d,%d ****\n", , + // and the following format otherwise: + // "*** %d ****\n", + // The ending line number of an empty range shall be the number of the preceding line, + // or 0 if the range is at the start of the file. + // + // Next, the range of lines in file2 shall be written in the following format + // if the range contains two or more lines: + // "--- %d,%d ----\n", , + // and the following format otherwise: + // "--- %d ----\n", + fm := formatRangeContext + assertEqual(t, fm(3, 3), "3") + assertEqual(t, fm(3, 4), "4") + assertEqual(t, fm(3, 5), "4,5") + assertEqual(t, fm(3, 6), "4,6") + assertEqual(t, fm(0, 0), "0") +} + +func TestOutputFormatTabDelimiter(t *testing.T) { + diff := UnifiedDiff{ + A: splitChars("one"), + B: splitChars("two"), + FromFile: "Original", + FromDate: "2005-01-26 23:30:50", + ToFile: "Current", + ToDate: "2010-04-12 10:20:52", + Eol: "\n", + } + ud, err := GetUnifiedDiffString(diff) + assertEqual(t, err, nil) + assertEqual(t, SplitLines(ud)[:2], []string{ + "--- Original\t2005-01-26 23:30:50\n", + "+++ Current\t2010-04-12 10:20:52\n", + }) + cd, err := GetContextDiffString(ContextDiff(diff)) + assertEqual(t, err, nil) + assertEqual(t, SplitLines(cd)[:2], []string{ + "*** Original\t2005-01-26 23:30:50\n", + "--- Current\t2010-04-12 10:20:52\n", + }) +} + +func TestOutputFormatNoTrailingTabOnEmptyFiledate(t *testing.T) { + diff := UnifiedDiff{ + A: splitChars("one"), + B: splitChars("two"), + FromFile: "Original", + ToFile: "Current", + Eol: "\n", + } + ud, err := GetUnifiedDiffString(diff) + assertEqual(t, err, nil) + assertEqual(t, SplitLines(ud)[:2], []string{"--- Original\n", "+++ Current\n"}) + + cd, err := GetContextDiffString(ContextDiff(diff)) + assertEqual(t, err, nil) + assertEqual(t, SplitLines(cd)[:2], []string{"*** Original\n", "--- Current\n"}) +} + +func TestOmitFilenames(t *testing.T) { + diff := UnifiedDiff{ + A: SplitLines("o\nn\ne\n"), + B: SplitLines("t\nw\no\n"), + Eol: "\n", + } + ud, err := GetUnifiedDiffString(diff) + assertEqual(t, err, nil) + assertEqual(t, SplitLines(ud), []string{ + "@@ -0,0 +1,2 @@\n", + "+t\n", + "+w\n", + "@@ -2,2 +3,0 @@\n", + "-n\n", + "-e\n", + "\n", + }) + + cd, err := GetContextDiffString(ContextDiff(diff)) + assertEqual(t, err, nil) + assertEqual(t, SplitLines(cd), []string{ + "***************\n", + "*** 0 ****\n", + "--- 1,2 ----\n", + "+ t\n", + "+ w\n", + "***************\n", + "*** 2,3 ****\n", + "- n\n", + "- e\n", + "--- 3 ----\n", + "\n", + }) +} + +func TestSplitLines(t *testing.T) { + allTests := []struct { + input string + want []string + }{ + {"foo", []string{"foo\n"}}, + {"foo\nbar", []string{"foo\n", "bar\n"}}, + {"foo\nbar\n", []string{"foo\n", "bar\n", "\n"}}, + } + for _, test := range allTests { + assertEqual(t, SplitLines(test.input), test.want) + } +} + +func benchmarkSplitLines(b *testing.B, count int) { + str := strings.Repeat("foo\n", count) + + b.ResetTimer() + + n := 0 + for i := 0; i < b.N; i++ { + n += len(SplitLines(str)) + } +} + +func BenchmarkSplitLines100(b *testing.B) { + benchmarkSplitLines(b, 100) +} + +func BenchmarkSplitLines10000(b *testing.B) { + benchmarkSplitLines(b, 10000) +} diff --git a/drivers/mongov2/internal/csfle/csfle.go b/drivers/mongov2/internal/csfle/csfle.go new file mode 100644 index 0000000..7a91045 --- /dev/null +++ b/drivers/mongov2/internal/csfle/csfle.go @@ -0,0 +1,40 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package csfle + +import ( + "errors" + "fmt" + + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" +) + +const ( + EncryptedCacheCollection = "ecc" + EncryptedStateCollection = "esc" + EncryptedCompactionCollection = "ecoc" +) + +// GetEncryptedStateCollectionName returns the encrypted state collection name associated with dataCollectionName. +func GetEncryptedStateCollectionName(efBSON bsoncore.Document, dataCollectionName string, stateCollection string) (string, error) { + fieldName := stateCollection + "Collection" + val, err := efBSON.LookupErr(fieldName) + if err != nil { + if !errors.Is(err, bsoncore.ErrElementNotFound) { + return "", err + } + // Return default name. + defaultName := "enxcol_." + dataCollectionName + "." + stateCollection + return defaultName, nil + } + + stateCollectionName, ok := val.StringValueOK() + if !ok { + return "", fmt.Errorf("expected string for '%v', got: %v", fieldName, val.Type) + } + return stateCollectionName, nil +} diff --git a/drivers/mongov2/internal/driverutil/description.go b/drivers/mongov2/internal/driverutil/description.go new file mode 100644 index 0000000..df3adc3 --- /dev/null +++ b/drivers/mongov2/internal/driverutil/description.go @@ -0,0 +1,493 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driverutil + +import ( + "errors" + "fmt" + "time" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/bsonutil" + "go.mongodb.org/mongo-driver/v2/internal/handshake" + "go.mongodb.org/mongo-driver/v2/internal/ptrutil" + "go.mongodb.org/mongo-driver/v2/mongo/address" + "go.mongodb.org/mongo-driver/v2/tag" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" +) + +const ( + MinWireVersion = 6 + MaxWireVersion = 25 +) + +func equalWireVersion(wv1, wv2 *description.VersionRange) bool { + if wv1 == nil && wv2 == nil { + return true + } + + if wv1 == nil || wv2 == nil { + return false + } + + return wv1.Min == wv2.Min && wv1.Max == wv2.Max +} + +// EqualServers compares two server descriptions and returns true if they are +// equal. +func EqualServers(srv1, srv2 description.Server) bool { + if srv1.CanonicalAddr.String() != srv2.CanonicalAddr.String() { + return false + } + + if !sliceStringEqual(srv1.Arbiters, srv2.Arbiters) { + return false + } + + if !sliceStringEqual(srv1.Hosts, srv2.Hosts) { + return false + } + + if !sliceStringEqual(srv1.Passives, srv2.Passives) { + return false + } + + if srv1.Primary != srv2.Primary { + return false + } + + if srv1.SetName != srv2.SetName { + return false + } + + if srv1.Kind != srv2.Kind { + return false + } + + if srv1.LastError != nil || srv2.LastError != nil { + if srv1.LastError == nil || srv2.LastError == nil { + return false + } + if srv1.LastError.Error() != srv2.LastError.Error() { + return false + } + } + + if !equalWireVersion(srv1.WireVersion, srv2.WireVersion) { + return false + } + + if len(srv1.Tags) != len(srv2.Tags) || !srv1.Tags.ContainsAll(srv2.Tags) { + return false + } + + if srv1.SetVersion != srv2.SetVersion { + return false + } + + if srv1.ElectionID != srv2.ElectionID { + return false + } + + if ptrutil.CompareInt64(srv1.SessionTimeoutMinutes, srv2.SessionTimeoutMinutes) != 0 { + return false + } + + // If TopologyVersion is nil for both servers, CompareToIncoming will return -1 because it assumes that the + // incoming response is newer. We want the descriptions to be considered equal in this case, though, so an + // explicit check is required. + if srv1.TopologyVersion == nil && srv2.TopologyVersion == nil { + return true + } + + return CompareTopologyVersions(srv1.TopologyVersion, srv2.TopologyVersion) == 0 +} + +// IsServerLoadBalanced checks if a description.Server describes a server that +// is load balanced. +func IsServerLoadBalanced(srv description.Server) bool { + return srv.Kind == description.ServerKindLoadBalancer || srv.ServiceID != nil +} + +// stringSliceFromRawElement decodes the provided BSON element into a []string. +// This internally calls StringSliceFromRawValue on the element's value. The +// error conditions outlined in that function's documentation apply for this +// function as well. +func stringSliceFromRawElement(element bson.RawElement) ([]string, error) { + return bsonutil.StringSliceFromRawValue(element.Key(), element.Value()) +} + +func decodeStringMap(element bson.RawElement, name string) (map[string]string, error) { + doc, ok := element.Value().DocumentOK() + if !ok { + return nil, fmt.Errorf("expected '%s' to be a document but it's a BSON %s", name, element.Value().Type) + } + elements, err := doc.Elements() + if err != nil { + return nil, err + } + m := make(map[string]string) + for _, element := range elements { + key := element.Key() + value, ok := element.Value().StringValueOK() + if !ok { + return nil, fmt.Errorf("expected '%s' to be a document of strings, but found a BSON %s", name, element.Value().Type) + } + m[key] = value + } + return m, nil +} + +// NewTopologyVersion creates a TopologyVersion based on doc +func NewTopologyVersion(doc bson.Raw) (*description.TopologyVersion, error) { + elements, err := doc.Elements() + if err != nil { + return nil, err + } + var tv description.TopologyVersion + var ok bool + for _, element := range elements { + switch element.Key() { + case "processId": + tv.ProcessID, ok = element.Value().ObjectIDOK() + if !ok { + return nil, fmt.Errorf("expected 'processId' to be a objectID but it's a BSON %s", element.Value().Type) + } + case "counter": + tv.Counter, ok = element.Value().Int64OK() + if !ok { + return nil, fmt.Errorf("expected 'counter' to be an int64 but it's a BSON %s", element.Value().Type) + } + } + } + return &tv, nil +} + +// NewVersionRange creates a new VersionRange given a min and a max. +func NewVersionRange(min, max int32) description.VersionRange { + return description.VersionRange{Min: min, Max: max} +} + +// VersionRangeIncludes returns a bool indicating whether the supplied integer +// is included in the range. +func VersionRangeIncludes(versionRange description.VersionRange, v int32) bool { + return v >= versionRange.Min && v <= versionRange.Max +} + +// CompareTopologyVersions compares the receiver, which represents the currently +// known TopologyVersion for a server, to an incoming TopologyVersion extracted +// from a server command response. +// +// This returns -1 if the receiver version is less than the response, 0 if the +// versions are equal, and 1 if the receiver version is greater than the +// response. This comparison is not commutative. +func CompareTopologyVersions(receiver, response *description.TopologyVersion) int { + if receiver == nil || response == nil { + return -1 + } + if receiver.ProcessID != response.ProcessID { + return -1 + } + if receiver.Counter == response.Counter { + return 0 + } + if receiver.Counter < response.Counter { + return -1 + } + return 1 +} + +// NewServerDescription creates a new server description from the given hello +// command response. +func NewServerDescription(addr address.Address, response bson.Raw) description.Server { + desc := description.Server{Addr: addr, CanonicalAddr: addr, LastUpdateTime: time.Now().UTC()} + elements, err := response.Elements() + if err != nil { + desc.LastError = err + return desc + } + var ok bool + var isReplicaSet, isWritablePrimary, hidden, secondary, arbiterOnly bool + var msg string + var versionRange description.VersionRange + for _, element := range elements { + switch element.Key() { + case "arbiters": + var err error + desc.Arbiters, err = stringSliceFromRawElement(element) + if err != nil { + desc.LastError = err + return desc + } + case "arbiterOnly": + arbiterOnly, ok = element.Value().BooleanOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'arbiterOnly' to be a boolean but it's a BSON %s", element.Value().Type) + return desc + } + case "compression": + var err error + desc.Compression, err = stringSliceFromRawElement(element) + if err != nil { + desc.LastError = err + return desc + } + case "electionId": + desc.ElectionID, ok = element.Value().ObjectIDOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'electionId' to be a objectID but it's a BSON %s", element.Value().Type) + return desc + } + case "iscryptd": + desc.IsCryptd, ok = element.Value().BooleanOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'iscryptd' to be a boolean but it's a BSON %s", element.Value().Type) + return desc + } + case "helloOk": + desc.HelloOK, ok = element.Value().BooleanOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'helloOk' to be a boolean but it's a BSON %s", element.Value().Type) + return desc + } + case "hidden": + hidden, ok = element.Value().BooleanOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'hidden' to be a boolean but it's a BSON %s", element.Value().Type) + return desc + } + case "hosts": + var err error + desc.Hosts, err = stringSliceFromRawElement(element) + if err != nil { + desc.LastError = err + return desc + } + case "isWritablePrimary": + isWritablePrimary, ok = element.Value().BooleanOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'isWritablePrimary' to be a boolean but it's a BSON %s", element.Value().Type) + return desc + } + case handshake.LegacyHelloLowercase: + isWritablePrimary, ok = element.Value().BooleanOK() + if !ok { + desc.LastError = fmt.Errorf("expected legacy hello to be a boolean but it's a BSON %s", element.Value().Type) + return desc + } + case "isreplicaset": + isReplicaSet, ok = element.Value().BooleanOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'isreplicaset' to be a boolean but it's a BSON %s", element.Value().Type) + return desc + } + case "lastWrite": + lastWrite, ok := element.Value().DocumentOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'lastWrite' to be a document but it's a BSON %s", element.Value().Type) + return desc + } + dateTime, err := lastWrite.LookupErr("lastWriteDate") + if err == nil { + dt, ok := dateTime.DateTimeOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'lastWriteDate' to be a datetime but it's a BSON %s", dateTime.Type) + return desc + } + desc.LastWriteTime = time.Unix(dt/1000, dt%1000*1000000).UTC() + } + case "logicalSessionTimeoutMinutes": + i64, ok := element.Value().AsInt64OK() + if !ok { + desc.LastError = fmt.Errorf("expected 'logicalSessionTimeoutMinutes' to be an integer but it's a BSON %s", element.Value().Type) + return desc + } + + desc.SessionTimeoutMinutes = &i64 + case "maxBsonObjectSize": + i64, ok := element.Value().AsInt64OK() + if !ok { + desc.LastError = fmt.Errorf("expected 'maxBsonObjectSize' to be an integer but it's a BSON %s", element.Value().Type) + return desc + } + desc.MaxDocumentSize = uint32(i64) + case "maxMessageSizeBytes": + i64, ok := element.Value().AsInt64OK() + if !ok { + desc.LastError = fmt.Errorf("expected 'maxMessageSizeBytes' to be an integer but it's a BSON %s", element.Value().Type) + return desc + } + desc.MaxMessageSize = uint32(i64) + case "maxWriteBatchSize": + i64, ok := element.Value().AsInt64OK() + if !ok { + desc.LastError = fmt.Errorf("expected 'maxWriteBatchSize' to be an integer but it's a BSON %s", element.Value().Type) + return desc + } + desc.MaxBatchCount = uint32(i64) + case "me": + me, ok := element.Value().StringValueOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'me' to be a string but it's a BSON %s", element.Value().Type) + return desc + } + desc.CanonicalAddr = address.Address(me).Canonicalize() + case "maxWireVersion": + verMax, ok := element.Value().AsInt64OK() + versionRange.Max = int32(verMax) + if !ok { + desc.LastError = fmt.Errorf("expected 'maxWireVersion' to be an integer but it's a BSON %s", element.Value().Type) + return desc + } + case "minWireVersion": + verMin, ok := element.Value().AsInt64OK() + versionRange.Min = int32(verMin) + if !ok { + desc.LastError = fmt.Errorf("expected 'minWireVersion' to be an integer but it's a BSON %s", element.Value().Type) + return desc + } + case "msg": + msg, ok = element.Value().StringValueOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'msg' to be a string but it's a BSON %s", element.Value().Type) + return desc + } + case "ok": + okay, ok := element.Value().AsInt64OK() + if !ok { + desc.LastError = fmt.Errorf("expected 'ok' to be a boolean but it's a BSON %s", element.Value().Type) + return desc + } + if okay != 1 { + desc.LastError = errors.New("not ok") + return desc + } + case "passives": + var err error + desc.Passives, err = stringSliceFromRawElement(element) + if err != nil { + desc.LastError = err + return desc + } + case "passive": + desc.Passive, ok = element.Value().BooleanOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'passive' to be a boolean but it's a BSON %s", element.Value().Type) + return desc + } + case "primary": + primary, ok := element.Value().StringValueOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'primary' to be a string but it's a BSON %s", element.Value().Type) + return desc + } + desc.Primary = address.Address(primary) + case "readOnly": + desc.ReadOnly, ok = element.Value().BooleanOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'readOnly' to be a boolean but it's a BSON %s", element.Value().Type) + return desc + } + case "secondary": + secondary, ok = element.Value().BooleanOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'secondary' to be a boolean but it's a BSON %s", element.Value().Type) + return desc + } + case "serviceId": + oid, ok := element.Value().ObjectIDOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'serviceId' to be an ObjectId but it's a BSON %s", element.Value().Type) + } + desc.ServiceID = &oid + case "setName": + desc.SetName, ok = element.Value().StringValueOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'setName' to be a string but it's a BSON %s", element.Value().Type) + return desc + } + case "setVersion": + i64, ok := element.Value().AsInt64OK() + if !ok { + desc.LastError = fmt.Errorf("expected 'setVersion' to be an integer but it's a BSON %s", element.Value().Type) + return desc + } + desc.SetVersion = uint32(i64) + case "tags": + m, err := decodeStringMap(element, "tags") + if err != nil { + desc.LastError = err + return desc + } + desc.Tags = tag.NewTagSetFromMap(m) + case "topologyVersion": + doc, ok := element.Value().DocumentOK() + if !ok { + desc.LastError = fmt.Errorf("expected 'topologyVersion' to be a document but it's a BSON %s", element.Value().Type) + return desc + } + + desc.TopologyVersion, err = NewTopologyVersion(doc) + if err != nil { + desc.LastError = err + return desc + } + } + } + + for _, host := range desc.Hosts { + desc.Members = append(desc.Members, address.Address(host).Canonicalize()) + } + + for _, passive := range desc.Passives { + desc.Members = append(desc.Members, address.Address(passive).Canonicalize()) + } + + for _, arbiter := range desc.Arbiters { + desc.Members = append(desc.Members, address.Address(arbiter).Canonicalize()) + } + + desc.Kind = description.ServerKindStandalone + + switch { + case isReplicaSet: + desc.Kind = description.ServerKindRSGhost + case desc.SetName != "": + switch { + case isWritablePrimary: + desc.Kind = description.ServerKindRSPrimary + case hidden: + desc.Kind = description.ServerKindRSMember + case secondary: + desc.Kind = description.ServerKindRSSecondary + case arbiterOnly: + desc.Kind = description.ServerKindRSArbiter + default: + desc.Kind = description.ServerKindRSMember + } + case msg == "isdbgrid": + desc.Kind = description.ServerKindMongos + } + + desc.WireVersion = &versionRange + + return desc +} + +func sliceStringEqual(a []string, b []string) bool { + if len(a) != len(b) { + return false + } + + for i, v := range a { + if v != b[i] { + return false + } + } + + return true +} diff --git a/drivers/mongov2/internal/driverutil/hello.go b/drivers/mongov2/internal/driverutil/hello.go new file mode 100644 index 0000000..18a70f0 --- /dev/null +++ b/drivers/mongov2/internal/driverutil/hello.go @@ -0,0 +1,128 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driverutil + +import ( + "os" + "strings" +) + +const AwsLambdaPrefix = "AWS_Lambda_" + +const ( + // FaaS environment variable names + + // EnvVarAWSExecutionEnv is the AWS Execution environment variable. + EnvVarAWSExecutionEnv = "AWS_EXECUTION_ENV" + // EnvVarAWSLambdaRuntimeAPI is the AWS Lambda runtime API variable. + EnvVarAWSLambdaRuntimeAPI = "AWS_LAMBDA_RUNTIME_API" + // EnvVarFunctionsWorkerRuntime is the functions worker runtime variable. + EnvVarFunctionsWorkerRuntime = "FUNCTIONS_WORKER_RUNTIME" + // EnvVarKService is the K Service variable. + EnvVarKService = "K_SERVICE" + // EnvVarFunctionName is the function name variable. + EnvVarFunctionName = "FUNCTION_NAME" + // EnvVarVercel is the Vercel variable. + EnvVarVercel = "VERCEL" + // EnvVarK8s is the K8s variable. + EnvVarK8s = "KUBERNETES_SERVICE_HOST" +) + +const ( + // FaaS environment variable names + + // EnvVarAWSRegion is the AWS region variable. + EnvVarAWSRegion = "AWS_REGION" + // EnvVarAWSLambdaFunctionMemorySize is the AWS Lambda function memory size variable. + EnvVarAWSLambdaFunctionMemorySize = "AWS_LAMBDA_FUNCTION_MEMORY_SIZE" + // EnvVarFunctionMemoryMB is the function memory in megabytes variable. + EnvVarFunctionMemoryMB = "FUNCTION_MEMORY_MB" + // EnvVarFunctionTimeoutSec is the function timeout in seconds variable. + EnvVarFunctionTimeoutSec = "FUNCTION_TIMEOUT_SEC" + // EnvVarFunctionRegion is the function region variable. + EnvVarFunctionRegion = "FUNCTION_REGION" + // EnvVarVercelRegion is the Vercel region variable. + EnvVarVercelRegion = "VERCEL_REGION" +) + +const ( + // FaaS environment names used by the client + + // EnvNameAWSLambda is the AWS Lambda environment name. + EnvNameAWSLambda = "aws.lambda" + // EnvNameAzureFunc is the Azure Function environment name. + EnvNameAzureFunc = "azure.func" + // EnvNameGCPFunc is the Google Cloud Function environment name. + EnvNameGCPFunc = "gcp.func" + // EnvNameVercel is the Vercel environment name. + EnvNameVercel = "vercel" +) + +// GetFaasEnvName parses the FaaS environment variable name and returns the +// corresponding name used by the client. If none of the variables or variables +// for multiple names are populated the client.env value MUST be entirely +// omitted. When variables for multiple "client.env.name" values are present, +// "vercel" takes precedence over "aws.lambda"; any other combination MUST cause +// "client.env" to be entirely omitted. +func GetFaasEnvName() string { + envVars := []string{ + EnvVarAWSExecutionEnv, + EnvVarAWSLambdaRuntimeAPI, + EnvVarFunctionsWorkerRuntime, + EnvVarKService, + EnvVarFunctionName, + EnvVarVercel, + } + + // If none of the variables are populated the client.env value MUST be + // entirely omitted. + names := make(map[string]struct{}) + + for _, envVar := range envVars { + val := os.Getenv(envVar) + if val == "" { + continue + } + + var name string + + switch envVar { + case EnvVarAWSExecutionEnv: + if !strings.HasPrefix(val, AwsLambdaPrefix) { + continue + } + + name = EnvNameAWSLambda + case EnvVarAWSLambdaRuntimeAPI: + name = EnvNameAWSLambda + case EnvVarFunctionsWorkerRuntime: + name = EnvNameAzureFunc + case EnvVarKService, EnvVarFunctionName: + name = EnvNameGCPFunc + case EnvVarVercel: + // "vercel" takes precedence over "aws.lambda". + delete(names, EnvNameAWSLambda) + + name = EnvNameVercel + } + + names[name] = struct{}{} + if len(names) > 1 { + // If multiple names are populated the client.env value + // MUST be entirely omitted. + names = nil + + break + } + } + + for name := range names { + return name + } + + return "" +} diff --git a/drivers/mongov2/internal/driverutil/operation.go b/drivers/mongov2/internal/driverutil/operation.go new file mode 100644 index 0000000..3270431 --- /dev/null +++ b/drivers/mongov2/internal/driverutil/operation.go @@ -0,0 +1,31 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driverutil + +// Operation Names should be sourced from the command reference documentation: +// https://www.mongodb.com/docs/manual/reference/command/ +const ( + AbortTransactionOp = "abortTransaction" // AbortTransactionOp is the name for aborting a transaction + AggregateOp = "aggregate" // AggregateOp is the name for aggregating + CommitTransactionOp = "commitTransaction" // CommitTransactionOp is the name for committing a transaction + CountOp = "count" // CountOp is the name for counting + CreateOp = "create" // CreateOp is the name for creating + CreateIndexesOp = "createIndexes" // CreateIndexesOp is the name for creating indexes + DeleteOp = "delete" // DeleteOp is the name for deleting + DistinctOp = "distinct" // DistinctOp is the name for distinct + DropOp = "drop" // DropOp is the name for dropping + DropDatabaseOp = "dropDatabase" // DropDatabaseOp is the name for dropping a database + DropIndexesOp = "dropIndexes" // DropIndexesOp is the name for dropping indexes + EndSessionsOp = "endSessions" // EndSessionsOp is the name for ending sessions + FindAndModifyOp = "findAndModify" // FindAndModifyOp is the name for finding and modifying + FindOp = "find" // FindOp is the name for finding + InsertOp = "insert" // InsertOp is the name for inserting + ListCollectionsOp = "listCollections" // ListCollectionsOp is the name for listing collections + ListIndexesOp = "listIndexes" // ListIndexesOp is the name for listing indexes + ListDatabasesOp = "listDatabases" // ListDatabasesOp is the name for listing databases + UpdateOp = "update" // UpdateOp is the name for updating +) diff --git a/drivers/mongov2/internal/failpoint/failpoint.go b/drivers/mongov2/internal/failpoint/failpoint.go new file mode 100644 index 0000000..9fe25ba --- /dev/null +++ b/drivers/mongov2/internal/failpoint/failpoint.go @@ -0,0 +1,63 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package failpoint + +import ( + "go.mongodb.org/mongo-driver/v2/bson" +) + +const ( + // ModeAlwaysOn is the fail point mode that enables the fail point for an + // indefinite number of matching commands. + ModeAlwaysOn = "alwaysOn" + + // ModeOff is the fail point mode that disables the fail point. + ModeOff = "off" +) + +// FailPoint is used to configure a server fail point. It is intended to be +// passed as the command argument to RunCommand. +// +// For more information about fail points, see +// https://github.com/mongodb/specifications/tree/HEAD/source/transactions/tests#server-fail-point +type FailPoint struct { + ConfigureFailPoint string `bson:"configureFailPoint"` + // Mode should be a string, FailPointMode, or map[string]interface{} + Mode interface{} `bson:"mode"` + Data Data `bson:"data"` +} + +// Mode configures when a fail point will be enabled. It is used to set the +// FailPoint.Mode field. +type Mode struct { + Times int32 `bson:"times"` + Skip int32 `bson:"skip"` +} + +// Data configures how a fail point will behave. It is used to set the +// FailPoint.Data field. +type Data struct { + FailCommands []string `bson:"failCommands,omitempty"` + CloseConnection bool `bson:"closeConnection,omitempty"` + ErrorCode int32 `bson:"errorCode,omitempty"` + FailBeforeCommitExceptionCode int32 `bson:"failBeforeCommitExceptionCode,omitempty"` + ErrorLabels *[]string `bson:"errorLabels,omitempty"` + WriteConcernError *WriteConcernError `bson:"writeConcernError,omitempty"` + BlockConnection bool `bson:"blockConnection,omitempty"` + BlockTimeMS int32 `bson:"blockTimeMS,omitempty"` + AppName string `bson:"appName,omitempty"` +} + +// WriteConcernError is the write concern error to return when the fail point is +// triggered. It is used to set the FailPoint.Data.WriteConcernError field. +type WriteConcernError struct { + Code int32 `bson:"code"` + Name string `bson:"codeName"` + Errmsg string `bson:"errmsg"` + ErrorLabels *[]string `bson:"errorLabels,omitempty"` + ErrInfo bson.Raw `bson:"errInfo,omitempty"` +} diff --git a/drivers/mongov2/internal/integtest/integtest.go b/drivers/mongov2/internal/integtest/integtest.go new file mode 100644 index 0000000..85a567e --- /dev/null +++ b/drivers/mongov2/internal/integtest/integtest.go @@ -0,0 +1,295 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package integtest + +import ( + "context" + "errors" + "fmt" + "math" + "os" + "reflect" + "strconv" + "strings" + "sync" + "testing" + + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/require" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/serverselector" + "go.mongodb.org/mongo-driver/v2/event" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/connstring" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/operation" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology" +) + +var connectionString *connstring.ConnString +var connectionStringOnce sync.Once +var connectionStringErr error +var liveTopology *topology.Topology +var liveTopologyOnce sync.Once +var liveTopologyErr error + +// AddOptionsToURI appends connection string options to a URI. +func AddOptionsToURI(uri string, opts ...string) string { + if !strings.ContainsRune(uri, '?') { + if uri[len(uri)-1] != '/' { + uri += "/" + } + + uri += "?" + } else { + uri += "&" + } + + for _, opt := range opts { + uri += opt + } + + return uri +} + +// AddTLSConfigToURI checks for the environmental variable indicating that the tests are being run +// on an SSL-enabled server, and if so, returns a new URI with the necessary configuration. +func AddTLSConfigToURI(uri string) string { + caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE") + if len(caFile) == 0 { + return uri + } + + return AddOptionsToURI(uri, "ssl=true&sslCertificateAuthorityFile=", caFile) +} + +// AddCompressorToURI checks for the environment variable indicating that the tests are being run with compression +// enabled. If so, it returns a new URI with the necessary configuration +func AddCompressorToURI(uri string) string { + comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR") + if len(comp) == 0 { + return uri + } + + return AddOptionsToURI(uri, "compressors=", comp) +} + +// AddTestServerAPIVersion adds the latest server API version in a ServerAPIOptions to passed-in opts. +func AddTestServerAPIVersion(opts *options.ClientOptions) { + if os.Getenv("REQUIRE_API_VERSION") == "true" { + opts.SetServerAPIOptions(options.ServerAPI(driver.TestServerAPIVersion)) + } +} + +// MonitoredTopology returns a new topology with the command monitor attached +func MonitoredTopology(t *testing.T, dbName string, monitor *event.CommandMonitor) *topology.Topology { + uri, err := MongoDBURI() + if err != nil { + t.Fatal(err) + } + + opts := options.Client().ApplyURI(uri).SetMonitor(monitor) + + cfg, err := topology.NewConfig(opts, nil) + if err != nil { + t.Fatal(err) + } + + monitoredTopology, err := topology.New(cfg) + if err != nil { + t.Fatal(err) + } else { + _ = monitoredTopology.Connect() + + err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))). + Database(dbName).ServerSelector(&serverselector.Write{}).Deployment(monitoredTopology).Execute(context.Background()) + + require.NoError(t, err) + } + + return monitoredTopology +} + +// Topology gets the globally configured topology. +func Topology(t *testing.T) *topology.Topology { + uri, err := MongoDBURI() + require.NoError(t, err, "error constructing mongodb URI: %v", err) + + opts := options.Client().ApplyURI(uri) + + cfg, err := topology.NewConfig(opts, nil) + require.NoError(t, err, "error constructing topology config: %v", err) + + liveTopologyOnce.Do(func() { + var err error + liveTopology, err = topology.New(cfg) + if err != nil { + liveTopologyErr = err + } else { + _ = liveTopology.Connect() + + err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))). + Database(DBName(t)).ServerSelector(&serverselector.Write{}). + Deployment(liveTopology).Execute(context.Background()) + require.NoError(t, err) + } + }) + + if liveTopologyErr != nil { + t.Fatal(liveTopologyErr) + } + + return liveTopology +} + +// TopologyWithCredential takes an "options.Credential" object and returns a connected topology. +func TopologyWithCredential(t *testing.T, credential options.Credential) *topology.Topology { + uri, err := MongoDBURI() + if err != nil { + t.Fatalf("error constructing mongodb URI: %v", err) + } + + opts := options.Client().ApplyURI(uri).SetAuth(credential) + + cfg, err := topology.NewConfig(opts, nil) + if err != nil { + t.Fatalf("error constructing topology config: %v", err) + } + topology, err := topology.New(cfg) + if err != nil { + t.Fatal("Could not construct topology") + } + err = topology.Connect() + if err != nil { + t.Fatal("Could not start topology connection") + } + return topology +} + +// ColName gets a collection name that should be unique +// to the currently executing test. +func ColName(t *testing.T) string { + // Get this indirectly to avoid copying a mutex + v := reflect.Indirect(reflect.ValueOf(t)) + name := v.FieldByName("name") + return name.String() +} + +// MongoDBURI will construct the MongoDB URI from the MONGODB_URI environment variable for testing. The default host is +// "localhost" and the default port is "27017" +func MongoDBURI() (string, error) { + uri := os.Getenv("MONGODB_URI") + if uri == "" { + uri = "mongodb://localhost:27017" + } + + uri = AddTLSConfigToURI(uri) + uri = AddCompressorToURI(uri) + uri, err := AddServerlessAuthCredentials(uri) + return uri, err +} + +// AddServerlessAuthCredentials will attempt to construct the serverless auth credentials for a URI. +func AddServerlessAuthCredentials(uri string) (string, error) { + if os.Getenv("SERVERLESS") != "serverless" { + return uri, nil + } + user := os.Getenv("SERVERLESS_ATLAS_USER") + if user == "" { + return "", fmt.Errorf("serverless expects SERVERLESS_ATLAS_USER to be set") + } + password := os.Getenv("SERVERLESS_ATLAS_PASSWORD") + if password == "" { + return "", fmt.Errorf("serverless expects SERVERLESS_ATLAS_PASSWORD to be set") + } + + var scheme string + // remove the scheme + switch { + case strings.HasPrefix(uri, "mongodb+srv://"): + scheme = "mongodb+srv://" + case strings.HasPrefix(uri, "mongodb://"): + scheme = "mongodb://" + default: + return "", errors.New(`scheme must be "mongodb" or "mongodb+srv"`) + } + + uri = scheme + user + ":" + password + "@" + uri[len(scheme):] + return uri, nil +} + +// ConnString gets the globally configured connection string. +func ConnString(t *testing.T) *connstring.ConnString { + connectionStringOnce.Do(func() { + uri, err := MongoDBURI() + require.NoError(t, err, "error constructing mongodb URI: %v", err) + + connectionString, err = connstring.ParseAndValidate(uri) + if err != nil { + connectionStringErr = err + } + }) + if connectionStringErr != nil { + t.Fatal(connectionStringErr) + } + + return connectionString +} + +func GetConnString() (*connstring.ConnString, error) { + mongodbURI := os.Getenv("MONGODB_URI") + if mongodbURI == "" { + mongodbURI = "mongodb://localhost:27017" + } + + mongodbURI = AddTLSConfigToURI(mongodbURI) + + cs, err := connstring.ParseAndValidate(mongodbURI) + if err != nil { + return nil, err + } + + return cs, nil +} + +// DBName gets the globally configured database name. +func DBName(t *testing.T) string { + return GetDBName(ConnString(t)) +} + +func GetDBName(cs *connstring.ConnString) string { + if cs.Database != "" { + return cs.Database + } + + return fmt.Sprintf("mongo-go-driver-%d", os.Getpid()) +} + +// CompareVersions compares two version number strings (i.e. positive integers separated by +// periods). Comparisons are done to the lesser precision of the two versions. For example, 3.2 is +// considered equal to 3.2.11, whereas 3.2.0 is considered less than 3.2.11. +// +// Returns a positive int if version1 is greater than version2, a negative int if version1 is less +// than version2, and 0 if version1 is equal to version2. +func CompareVersions(t *testing.T, v1 string, v2 string) int { + n1 := strings.Split(v1, ".") + n2 := strings.Split(v2, ".") + + for i := 0; i < int(math.Min(float64(len(n1)), float64(len(n2)))); i++ { + i1, err := strconv.Atoi(n1[i]) + require.NoError(t, err) + + i2, err := strconv.Atoi(n2[i]) + require.NoError(t, err) + + difference := i1 - i2 + if difference != 0 { + return difference + } + } + + return 0 +} diff --git a/drivers/mongov2/internal/mongoutil/mongoutil.go b/drivers/mongov2/internal/mongoutil/mongoutil.go new file mode 100644 index 0000000..0345b96 --- /dev/null +++ b/drivers/mongov2/internal/mongoutil/mongoutil.go @@ -0,0 +1,85 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mongoutil + +import ( + "reflect" + + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +// NewOptions will functionally merge a slice of mongo.Options in a +// "last-one-wins" manner, where nil options are ignored. +func NewOptions[T any](opts ...options.Lister[T]) (*T, error) { + args := new(T) + for _, opt := range opts { + if opt == nil || reflect.ValueOf(opt).IsNil() { + // Do nothing if the option is nil or if opt is nil but implicitly cast as + // an Options interface by the NewArgsFromOptions function. The latter + // case would look something like this: + continue + } + + for _, setArgs := range opt.List() { + if setArgs == nil { + continue + } + + if err := setArgs(args); err != nil { + return nil, err + } + } + } + return args, nil +} + +// OptionsLister implements an options.SetterLister object for an arbitrary +// options type. +type OptionsLister[T any] struct { + Options *T // Arguments to set on the option type + Callback func(*T) error // A callback for further modification +} + +// List will re-assign the entire argument option to the Args field +// defined on opts. If a callback exists, that function will be executed to +// further modify the arguments. +func (opts *OptionsLister[T]) List() []func(*T) error { + return []func(*T) error{ + func(args *T) error { + if opts.Options != nil { + *args = *opts.Options + } + + if opts.Callback != nil { + return opts.Callback(args) + } + + return nil + }, + } +} + +// NewOptionsLister will construct a SetterLister from the provided Options +// object. +func NewOptionsLister[T any](args *T, callback func(*T) error) *OptionsLister[T] { + return &OptionsLister[T]{Options: args, Callback: callback} +} + +// AuthFromURI will create a Credentials object given the provided URI. +func AuthFromURI(uri string) (*options.Credential, error) { + opts := options.Client().ApplyURI(uri) + + return opts.Auth, nil +} + +// HostsFromURI will parse the hosts in the URI and return them as a slice of +// strings. +func HostsFromURI(uri string) ([]string, error) { + opts := options.Client().ApplyURI(uri) + + return opts.Hosts, nil +} diff --git a/drivers/mongov2/internal/mongoutil/mongoutil_test.go b/drivers/mongov2/internal/mongoutil/mongoutil_test.go new file mode 100644 index 0000000..661ee5f --- /dev/null +++ b/drivers/mongov2/internal/mongoutil/mongoutil_test.go @@ -0,0 +1,34 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mongoutil + +import ( + "strings" + "testing" + + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +func BenchmarkNewOptions(b *testing.B) { + b.Run("reflect.ValueOf is always called", func(b *testing.B) { + opts := make([]options.Lister[options.FindOptions], b.N) + + // Create a huge string to see if we can force reflect.ValueOf to use heap + // over stack. + size := 16 * 1024 * 1024 + str := strings.Repeat("a", size) + + for i := 0; i < b.N; i++ { + opts[i] = options.Find().SetComment(str).SetHint("y").SetMin(1).SetMax(2) + } + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = NewOptions[options.FindOptions](opts...) + } + }) +} diff --git a/drivers/mongov2/internal/mtest/csfle_enabled.go b/drivers/mongov2/internal/mtest/csfle_enabled.go new file mode 100644 index 0000000..588e9ad --- /dev/null +++ b/drivers/mongov2/internal/mtest/csfle_enabled.go @@ -0,0 +1,16 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +//go:build cse +// +build cse + +package mtest + +// IsCSFLEEnabled returns true if driver is built with Client Side Field Level Encryption support. +// Client Side Field Level Encryption support is enabled with the cse build tag. +func IsCSFLEEnabled() bool { + return true +} diff --git a/drivers/mongov2/internal/mtest/csfle_not_enabled.go b/drivers/mongov2/internal/mtest/csfle_not_enabled.go new file mode 100644 index 0000000..289cf5c --- /dev/null +++ b/drivers/mongov2/internal/mtest/csfle_not_enabled.go @@ -0,0 +1,16 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +//go:build !cse +// +build !cse + +package mtest + +// IsCSFLEEnabled returns true if driver is built with Client Side Field Level Encryption support. +// Client Side Field Level Encryption support is enabled with the cse build tag. +func IsCSFLEEnabled() bool { + return false +} diff --git a/drivers/mongov2/internal/mtest/deployment_helpers.go b/drivers/mongov2/internal/mtest/deployment_helpers.go new file mode 100644 index 0000000..8683303 --- /dev/null +++ b/drivers/mongov2/internal/mtest/deployment_helpers.go @@ -0,0 +1,120 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mtest + +import ( + "go.mongodb.org/mongo-driver/v2/bson" +) + +// BatchIdentifier specifies the keyword to identify the batch in a cursor response. +type BatchIdentifier string + +// These constants specify valid values for BatchIdentifier. +const ( + FirstBatch BatchIdentifier = "firstBatch" + NextBatch BatchIdentifier = "nextBatch" +) + +// CommandError is a representation of a command error from the server. +type CommandError struct { + Code int32 + Message string + Name string + Labels []string +} + +// WriteError is a representation of a write error from the server. +type WriteError struct { + Index int + Code int + Message string +} + +// WriteConcernError is a representation of a write concern error from the server. +type WriteConcernError struct { + Name string `bson:"codeName"` + Code int `bson:"code"` + Message string `bson:"errmsg"` + Details bson.Raw `bson:"errInfo"` +} + +// CreateCursorResponse creates a response for a cursor command. +func CreateCursorResponse(cursorID int64, ns string, identifier BatchIdentifier, batch ...bson.D) bson.D { + batchArr := bson.A{} + for _, doc := range batch { + batchArr = append(batchArr, doc) + } + + return bson.D{ + {"ok", 1}, + {"cursor", bson.D{ + {"id", cursorID}, + {"ns", ns}, + {string(identifier), batchArr}, + }}, + } +} + +// CreateCommandErrorResponse creates a response with a command error. +func CreateCommandErrorResponse(ce CommandError) bson.D { + res := bson.D{ + {"ok", 0}, + {"code", ce.Code}, + {"errmsg", ce.Message}, + {"codeName", ce.Name}, + } + if len(ce.Labels) > 0 { + var labelsArr bson.A + for _, label := range ce.Labels { + labelsArr = append(labelsArr, label) + } + res = append(res, bson.E{Key: "errorLabels", Value: labelsArr}) + } + return res +} + +// CreateWriteErrorsResponse creates a response with one or more write errors. +func CreateWriteErrorsResponse(writeErrorrs ...WriteError) bson.D { + arr := make(bson.A, len(writeErrorrs)) + for idx, we := range writeErrorrs { + arr[idx] = bson.D{ + {"index", we.Index}, + {"code", we.Code}, + {"errmsg", we.Message}, + } + } + + return bson.D{ + {"ok", 1}, + {"writeErrors", arr}, + } +} + +// CreateWriteConcernErrorResponse creates a response with a write concern error. +func CreateWriteConcernErrorResponse(wce WriteConcernError) bson.D { + wceDoc := bson.D{ + {"code", wce.Code}, + {"codeName", wce.Name}, + {"errmsg", wce.Message}, + } + if len(wce.Details) > 0 { + wceDoc = append(wceDoc, bson.E{Key: "errInfo", Value: wce.Details}) + } + + return bson.D{ + {"ok", 1}, + {"writeConcernError", wceDoc}, + } +} + +// CreateSuccessResponse creates a response for a successful operation with the given elements. +func CreateSuccessResponse(elems ...bson.E) bson.D { + res := bson.D{ + {"ok", 1}, + } + return append(res, elems...) +} diff --git a/drivers/mongov2/internal/mtest/doc.go b/drivers/mongov2/internal/mtest/doc.go new file mode 100644 index 0000000..9d4ae6f --- /dev/null +++ b/drivers/mongov2/internal/mtest/doc.go @@ -0,0 +1,9 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +// Package mtest is unstable and there is no backward compatibility guarantee. +// It is experimental and subject to change. +package mtest diff --git a/drivers/mongov2/internal/mtest/global_state.go b/drivers/mongov2/internal/mtest/global_state.go new file mode 100644 index 0000000..1f0d0a9 --- /dev/null +++ b/drivers/mongov2/internal/mtest/global_state.go @@ -0,0 +1,96 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mtest + +import ( + "context" + "fmt" + + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/failpoint" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/connstring" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology" +) + +// AuthEnabled returns whether or not the cluster requires auth. +func AuthEnabled() bool { + return testContext.authEnabled +} + +// SSLEnabled returns whether or not the cluster requires SSL. +func SSLEnabled() bool { + return testContext.sslEnabled +} + +// ClusterTopologyKind returns the topology kind of the cluster under test. +func ClusterTopologyKind() TopologyKind { + return testContext.topoKind +} + +// ClusterURI returns the connection string for the cluster. +func ClusterURI() string { + return testContext.connString.Original +} + +// Serverless returns whether the test is running against a serverless instance. +func Serverless() bool { + return testContext.serverless +} + +// SingleMongosLoadBalancerURI returns the URI for a load balancer fronting a single mongos. This will only be set +// if the cluster is load balanced. +func SingleMongosLoadBalancerURI() string { + return testContext.singleMongosLoadBalancerURI +} + +// MultiMongosLoadBalancerURI returns the URI for a load balancer fronting multiple mongoses. This will only be set +// if the cluster is load balanced. +func MultiMongosLoadBalancerURI() string { + return testContext.multiMongosLoadBalancerURI +} + +// ClusterConnString returns the parsed ConnString for the cluster. +func ClusterConnString() *connstring.ConnString { + return testContext.connString +} + +// GlobalClient returns a Client connected to the cluster configured with read concern majority, write concern majority, +// and read preference primary. +func GlobalClient() *mongo.Client { + return testContext.client +} + +// GlobalTopology returns a Topology that's connected to the cluster. +func GlobalTopology() *topology.Topology { + return testContext.topo +} + +// ServerVersion returns the server version of the cluster. This assumes that all nodes in the cluster have the same +// version. +func ServerVersion() string { + return testContext.serverVersion +} + +// SetFailPoint configures the provided fail point on the cluster under test using the provided Client. +func SetFailPoint(fp failpoint.FailPoint, client *mongo.Client) error { + admin := client.Database("admin") + if err := admin.RunCommand(context.Background(), fp).Err(); err != nil { + return fmt.Errorf("error creating fail point: %w", err) + } + return nil +} + +// SetRawFailPoint configures the fail point represented by the fp parameter on the cluster under test using the +// provided Client +func SetRawFailPoint(fp bson.Raw, client *mongo.Client) error { + admin := client.Database("admin") + if err := admin.RunCommand(context.Background(), fp).Err(); err != nil { + return fmt.Errorf("error creating fail point: %w", err) + } + return nil +} diff --git a/drivers/mongov2/internal/mtest/mongotest.go b/drivers/mongov2/internal/mtest/mongotest.go new file mode 100644 index 0000000..f510831 --- /dev/null +++ b/drivers/mongov2/internal/mtest/mongotest.go @@ -0,0 +1,874 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mtest + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "sync/atomic" + "testing" + + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/assert" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/csfle" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/failpoint" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/mongoutil" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/require" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/event" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/mongo/readconcern" + "go.mongodb.org/mongo-driver/v2/mongo/readpref" + "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest" +) + +var ( + // MajorityWc is the majority write concern. + MajorityWc = writeconcern.Majority() + // PrimaryRp is the primary read preference. + PrimaryRp = readpref.Primary() + // SecondaryRp is the secondary read preference. + SecondaryRp = readpref.Secondary() + // LocalRc is the local read concern + LocalRc = readconcern.Local() + // MajorityRc is the majority read concern + MajorityRc = readconcern.Majority() +) + +const ( + namespaceExistsErrCode int32 = 48 +) + +// T is a wrapper around testing.T. +type T struct { + // connsCheckedOut is the net number of connections checked out during test execution. + // It must be accessed using the atomic package and should be at the beginning of the struct. + // - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG + // - suggested layout: https://go101.org/article/memory-layout.html + connsCheckedOut int64 + + *testing.T + + // members for only this T instance + createClient *bool + createCollection *bool + runOn []RunOnBlock + mockDeployment *drivertest.MockDeployment // nil if the test is not being run against a mock + mockResponses []bson.D + createdColls []*Collection // collections created in this test + proxyDialer *proxyDialer + dbName, collName string + failPointNames []string + minServerVersion string + maxServerVersion string + validTopologies []TopologyKind + auth *bool + enterprise *bool + dataLake *bool + ssl *bool + collCreateOpts *options.CreateCollectionOptionsBuilder + requireAPIVersion *bool + + // options copied to sub-tests + clientType ClientType + clientOpts *options.ClientOptions + collOpts *options.CollectionOptionsBuilder + shareClient *bool + + baseOpts *Options // used to create subtests + + // command monitoring channels + monitorLock sync.Mutex + started []*event.CommandStartedEvent + succeeded []*event.CommandSucceededEvent + failed []*event.CommandFailedEvent + + Client *mongo.Client + DB *mongo.Database + Coll *mongo.Collection +} + +func newT(wrapped *testing.T, opts ...*Options) *T { + t := &T{ + T: wrapped, + } + for _, opt := range opts { + for _, optFn := range opt.optFuncs { + optFn(t) + } + } + + if err := t.verifyConstraints(); err != nil { + t.Skipf("skipping due to environmental constraints: %v", err) + } + + if t.collName == "" { + t.collName = t.Name() + } + if t.dbName == "" { + t.dbName = TestDb + } + t.collName = sanitizeCollectionName(t.dbName, t.collName) + + // create a set of base options for sub-tests + t.baseOpts = NewOptions().ClientOptions(t.clientOpts).CollectionOptions(t.collOpts).ClientType(t.clientType) + if t.shareClient != nil { + t.baseOpts.ShareClient(*t.shareClient) + } + + return t +} + +// New creates a new T instance with the given options. If the current environment does not satisfy constraints +// specified in the options, the test will be skipped automatically. +func New(wrapped *testing.T, opts ...*Options) *T { + // All tests that use mtest.New() are expected to be integration tests, so skip them when the + // -short flag is included in the "go test" command. + if testing.Short() { + wrapped.Skip("skipping mtest integration test in short mode") + } + + t := newT(wrapped, opts...) + + // only create a client if it needs to be shared in sub-tests + // otherwise, a new client will be created for each subtest + if t.shareClient != nil && *t.shareClient { + t.createTestClient() + } + + wrapped.Cleanup(t.cleanup) + + return t +} + +// cleanup cleans up any resources associated with a T. It is intended to be +// called by [testing.T.Cleanup]. +func (t *T) cleanup() { + if t.Client == nil { + return + } + + // only clear collections and fail points if the test is not running against a mock + if t.clientType != Mock { + t.ClearCollections() + t.ClearFailPoints() + } + + // always disconnect the client regardless of clientType because Client.Disconnect will work against + // all deployments + _ = t.Client.Disconnect(context.Background()) +} + +// Run creates a new T instance for a sub-test and runs the given callback. It also creates a new collection using the +// given name which is available to the callback through the T.Coll variable and is dropped after the callback +// returns. +func (t *T) Run(name string, callback func(mt *T)) { + t.RunOpts(name, NewOptions(), callback) +} + +// RunOpts creates a new T instance for a sub-test with the given options. If the current environment does not satisfy +// constraints specified in the options, the new sub-test will be skipped automatically. If the test is not skipped, +// the callback will be run with the new T instance. RunOpts creates a new collection with the given name which is +// available to the callback through the T.Coll variable and is dropped after the callback returns. +func (t *T) RunOpts(name string, opts *Options, callback func(mt *T)) { + t.T.Run(name, func(wrapped *testing.T) { + sub := newT(wrapped, t.baseOpts, opts) + + // add any mock responses for this test + if sub.clientType == Mock && len(sub.mockResponses) > 0 { + sub.AddMockResponses(sub.mockResponses...) + } + + // for shareClient, inherit the client from the parent + if sub.shareClient != nil && *sub.shareClient && sub.clientType == t.clientType { + sub.Client = t.Client + } + // only create a client if not already set + if sub.Client == nil { + if sub.createClient == nil || *sub.createClient { + sub.createTestClient() + } + } + // create a collection for this test + if sub.Client != nil { + sub.createTestCollection() + } + + // defer dropping all collections if the test is using a client + defer func() { + if sub.Client == nil { + return + } + + // store number of sessions and connections checked out here but assert that they're equal to 0 after + // cleaning up test resources to make sure resources are always cleared + sessions := sub.Client.NumberSessionsInProgress() + conns := sub.NumberConnectionsCheckedOut() + + if sub.clientType != Mock { + sub.ClearFailPoints() + sub.ClearCollections() + } + // only disconnect client if it's not being shared + if sub.shareClient == nil || !*sub.shareClient { + _ = sub.Client.Disconnect(context.Background()) + } + assert.Equal(sub, 0, sessions, "%v sessions checked out", sessions) + assert.Equal(sub, 0, conns, "%v connections checked out", conns) + }() + + // clear any events that may have happened during setup and run the test + sub.ClearEvents() + callback(sub) + }) +} + +// AddMockResponses adds responses to be returned by the mock deployment. This should only be used if T is being run +// against a mock deployment. +func (t *T) AddMockResponses(responses ...bson.D) { + t.mockDeployment.AddResponses(responses...) +} + +// ClearMockResponses clears all responses in the mock deployment. +func (t *T) ClearMockResponses() { + t.mockDeployment.ClearResponses() +} + +// GetStartedEvent returns the least recent CommandStartedEvent, or nil if one is not present. +// This can only be called once per event. +func (t *T) GetStartedEvent() *event.CommandStartedEvent { + if len(t.started) == 0 { + return nil + } + e := t.started[0] + t.started = t.started[1:] + return e +} + +// GetSucceededEvent returns the least recent CommandSucceededEvent, or nil if one is not present. +// This can only be called once per event. +func (t *T) GetSucceededEvent() *event.CommandSucceededEvent { + if len(t.succeeded) == 0 { + return nil + } + e := t.succeeded[0] + t.succeeded = t.succeeded[1:] + return e +} + +// GetFailedEvent returns the least recent CommandFailedEvent, or nil if one is not present. +// This can only be called once per event. +func (t *T) GetFailedEvent() *event.CommandFailedEvent { + if len(t.failed) == 0 { + return nil + } + e := t.failed[0] + t.failed = t.failed[1:] + return e +} + +// GetAllStartedEvents returns a slice of all CommandStartedEvent instances for this test. This can be called multiple +// times. +func (t *T) GetAllStartedEvents() []*event.CommandStartedEvent { + return t.started +} + +// GetAllSucceededEvents returns a slice of all CommandSucceededEvent instances for this test. This can be called multiple +// times. +func (t *T) GetAllSucceededEvents() []*event.CommandSucceededEvent { + return t.succeeded +} + +// GetAllFailedEvents returns a slice of all CommandFailedEvent instances for this test. This can be called multiple +// times. +func (t *T) GetAllFailedEvents() []*event.CommandFailedEvent { + return t.failed +} + +// FilterStartedEvents filters the existing CommandStartedEvent instances for this test using the provided filter +// callback. An event will be retained if the filter returns true. The list of filtered events will be used to overwrite +// the list of events for this test and will therefore change the output of t.GetAllStartedEvents(). +func (t *T) FilterStartedEvents(filter func(*event.CommandStartedEvent) bool) { + var newEvents []*event.CommandStartedEvent + for _, evt := range t.started { + if filter(evt) { + newEvents = append(newEvents, evt) + } + } + t.started = newEvents +} + +// FilterSucceededEvents filters the existing CommandSucceededEvent instances for this test using the provided filter +// callback. An event will be retained if the filter returns true. The list of filtered events will be used to overwrite +// the list of events for this test and will therefore change the output of t.GetAllSucceededEvents(). +func (t *T) FilterSucceededEvents(filter func(*event.CommandSucceededEvent) bool) { + var newEvents []*event.CommandSucceededEvent + for _, evt := range t.succeeded { + if filter(evt) { + newEvents = append(newEvents, evt) + } + } + t.succeeded = newEvents +} + +// FilterFailedEvents filters the existing CommandFailedEVent instances for this test using the provided filter +// callback. An event will be retained if the filter returns true. The list of filtered events will be used to overwrite +// the list of events for this test and will therefore change the output of t.GetAllFailedEvents(). +func (t *T) FilterFailedEvents(filter func(*event.CommandFailedEvent) bool) { + var newEvents []*event.CommandFailedEvent + for _, evt := range t.failed { + if filter(evt) { + newEvents = append(newEvents, evt) + } + } + t.failed = newEvents +} + +// GetProxiedMessages returns the messages proxied to the server by the test. If the client type is not Proxy, this +// returns nil. +func (t *T) GetProxiedMessages() []*ProxyMessage { + if t.proxyDialer == nil { + return nil + } + return t.proxyDialer.Messages() +} + +// NumberConnectionsCheckedOut returns the number of connections checked out from the test Client. +func (t *T) NumberConnectionsCheckedOut() int { + return int(atomic.LoadInt64(&t.connsCheckedOut)) +} + +// ClearEvents clears the existing command monitoring events. +func (t *T) ClearEvents() { + t.started = t.started[:0] + t.succeeded = t.succeeded[:0] + t.failed = t.failed[:0] +} + +// ResetClient resets the existing client with the given options. If opts is nil, the existing options will be used. +// If t.Coll is not-nil, it will be reset to use the new client. Should only be called if the existing client is +// not nil. This will Disconnect the existing client but will not drop existing collections. To do so, ClearCollections +// must be called before calling ResetClient. +func (t *T) ResetClient(opts *options.ClientOptions) { + if opts != nil { + t.clientOpts = opts + } + + _ = t.Client.Disconnect(context.Background()) + t.createTestClient() + t.DB = t.Client.Database(t.dbName) + t.Coll = t.DB.Collection(t.collName, t.collOpts) + + for _, coll := range t.createdColls { + // If the collection was created using a different Client, it doesn't need to be reset. + if coll.hasDifferentClient { + continue + } + + // If the namespace is the same as t.Coll, we can use t.Coll. + if coll.created.Name() == t.collName && coll.created.Database().Name() == t.dbName { + coll.created = t.Coll + continue + } + + // Otherwise, reset the collection to use the new Client. + coll.created = t.Client.Database(coll.DB).Collection(coll.Name, coll.Opts) + } +} + +// Collection is used to configure a new collection created during a test. +type Collection struct { + Name string + DB string // defaults to mt.DB.Name() if not specified + Client *mongo.Client // defaults to mt.Client if not specified + Opts *options.CollectionOptionsBuilder + CreateOpts *options.CreateCollectionOptionsBuilder + ViewOn string + ViewPipeline interface{} + hasDifferentClient bool + created *mongo.Collection // the actual collection that was created +} + +// CreateCollection creates a new collection with the given configuration. The collection will be dropped after the test +// finishes running. If createOnServer is true, the function ensures that the collection has been created server-side +// by running the create command. The create command will appear in command monitoring channels. +func (t *T) CreateCollection(coll Collection, createOnServer bool) *mongo.Collection { + if coll.DB == "" { + coll.DB = t.DB.Name() + } + if coll.Client == nil { + coll.Client = t.Client + } + coll.hasDifferentClient = coll.Client != t.Client + + db := coll.Client.Database(coll.DB) + + opts, err := mongoutil.NewOptions[options.CreateCollectionOptions](coll.CreateOpts) + require.NoError(t, err, "failed to construct options from builder") + + if coll.CreateOpts != nil && opts.EncryptedFields != nil { + // An encrypted collection consists of a data collection and three state collections. + // Aborted test runs may leave these collections. + // Drop all four collections to avoid a quiet failure to create all collections. + DropEncryptedCollection(t, db.Collection(coll.Name), opts.EncryptedFields) + } + + if createOnServer && t.clientType != Mock { + var err error + if coll.ViewOn != "" { + err = db.CreateView(context.Background(), coll.Name, coll.ViewOn, coll.ViewPipeline) + } else { + err = db.CreateCollection(context.Background(), coll.Name, coll.CreateOpts) + } + + // ignore ErrUnacknowledgedWrite. Client may be configured with unacknowledged write concern. + if err != nil && !errors.Is(err, driver.ErrUnacknowledgedWrite) { + // ignore NamespaceExists errors for idempotency + + var cmdErr mongo.CommandError + if !errors.As(err, &cmdErr) || cmdErr.Code != namespaceExistsErrCode { + t.Fatalf("error creating collection or view: %v on server: %v", coll.Name, err) + } + } + } + + coll.created = db.Collection(coll.Name, coll.Opts) + t.createdColls = append(t.createdColls, &coll) + return coll.created +} + +// DropEncryptedCollection drops a collection with EncryptedFields. +// The EncryptedFields option is not supported in Collection.Drop(). See GODRIVER-2413. +func DropEncryptedCollection(t *T, coll *mongo.Collection, encryptedFields interface{}) { + t.Helper() + + var efBSON bsoncore.Document + efBSON, err := bson.Marshal(encryptedFields) + assert.Nil(t, err, "error in Marshal: %v", err) + + // Drop the two encryption-related, associated collections: `escCollection` and `ecocCollection`. + // Drop ESCCollection. + escCollection, err := csfle.GetEncryptedStateCollectionName(efBSON, coll.Name(), csfle.EncryptedStateCollection) + assert.Nil(t, err, "error in getEncryptedStateCollectionName: %v", err) + err = coll.Database().Collection(escCollection).Drop(context.Background()) + assert.Nil(t, err, "error in Drop: %v", err) + + // Drop ECOCCollection. + ecocCollection, err := csfle.GetEncryptedStateCollectionName(efBSON, coll.Name(), csfle.EncryptedCompactionCollection) + assert.Nil(t, err, "error in getEncryptedStateCollectionName: %v", err) + err = coll.Database().Collection(ecocCollection).Drop(context.Background()) + assert.Nil(t, err, "error in Drop: %v", err) + + // Drop the data collection. + err = coll.Drop(context.Background()) + assert.Nil(t, err, "error in Drop: %v", err) +} + +// ClearCollections drops all collections previously created by this test. +func (t *T) ClearCollections() { + // Collections should not be dropped when testing against Atlas Data Lake because the data is pre-inserted. + if !testContext.dataLake { + for _, coll := range t.createdColls { + opts, err := mongoutil.NewOptions[options.CreateCollectionOptions](coll.CreateOpts) + require.NoError(t, err, "failed to construct options from builder") + + if coll.CreateOpts != nil && opts.EncryptedFields != nil { + DropEncryptedCollection(t, coll.created, opts.EncryptedFields) + } + + // It's possible that a collection could have an unacknowledged write + // concern, which could prevent it from being dropped for sharded + // clusters. We can resolve this by re-instantiating the collection with + // a majority write concern before dropping. + clonedColl := coll.created.Clone(options.Collection().SetWriteConcern(writeconcern.Majority())) + + _ = clonedColl.Drop(context.Background()) + } + } + t.createdColls = t.createdColls[:0] +} + +// SetFailPoint sets a fail point for the client associated with T. Commands to create the failpoint will appear +// in command monitoring channels. The fail point will automatically be disabled after this test has run. +func (t *T) SetFailPoint(fp failpoint.FailPoint) { + // ensure mode fields are int32 + if modeMap, ok := fp.Mode.(map[string]interface{}); ok { + var key string + var err error + + if times, ok := modeMap["times"]; ok { + key = "times" + modeMap["times"], err = t.interfaceToInt32(times) + } + if skip, ok := modeMap["skip"]; ok { + key = "skip" + modeMap["skip"], err = t.interfaceToInt32(skip) + } + + if err != nil { + t.Fatalf("error converting %s to int32: %v", key, err) + } + } + + if err := SetFailPoint(fp, t.Client); err != nil { + t.Fatal(err) + } + t.failPointNames = append(t.failPointNames, fp.ConfigureFailPoint) +} + +// SetFailPointFromDocument sets the fail point represented by the given document for the client associated with T. This +// method assumes that the given document is in the form {configureFailPoint: , ...}. Commands to create +// the failpoint will appear in command monitoring channels. The fail point will be automatically disabled after this +// test has run. +func (t *T) SetFailPointFromDocument(fp bson.Raw) { + if err := SetRawFailPoint(fp, t.Client); err != nil { + t.Fatal(err) + } + + name := fp.Index(0).Value().StringValue() + t.failPointNames = append(t.failPointNames, name) +} + +// TrackFailPoint adds the given fail point to the list of fail points to be disabled when the current test finishes. +// This function does not create a fail point on the server. +func (t *T) TrackFailPoint(fpName string) { + t.failPointNames = append(t.failPointNames, fpName) +} + +// ClearFailPoints disables all previously set failpoints for this test. +func (t *T) ClearFailPoints() { + db := t.Client.Database("admin") + for _, fp := range t.failPointNames { + cmd := failpoint.FailPoint{ + ConfigureFailPoint: fp, + Mode: failpoint.ModeOff, + } + err := db.RunCommand(context.Background(), cmd).Err() + if err != nil { + t.Fatalf("error clearing fail point %s: %v", fp, err) + } + } + t.failPointNames = t.failPointNames[:0] +} + +// CloneDatabase modifies the default database for this test to match the given options. +func (t *T) CloneDatabase(opts *options.DatabaseOptionsBuilder) { + t.DB = t.Client.Database(t.dbName, opts) +} + +// CloneCollection modifies the default collection for this test to match the given options. +func (t *T) CloneCollection(opts *options.CollectionOptionsBuilder) { + t.Coll = t.Coll.Clone(opts) +} + +func sanitizeCollectionName(db string, coll string) string { + // Collections can't have "$" in their names, so we substitute it with "%". + coll = strings.ReplaceAll(coll, "$", "%") + + // Namespaces can only have 120 bytes max. + if len(db+"."+coll) >= 120 { + // coll len must be <= remaining + remaining := 120 - (len(db) + 1) // +1 for "." + coll = coll[len(coll)-remaining:] + } + return coll +} + +func (t *T) createTestClient() { + clientOpts := t.clientOpts + + if t.clientOpts == nil { + // default opts + clientOpts = options.Client().SetWriteConcern(MajorityWc).SetReadPreference(PrimaryRp) + } + + // set ServerAPIOptions to latest version if required + if clientOpts.Deployment == nil && t.clientType != Mock && clientOpts.ServerAPIOptions == nil && testContext.requireAPIVersion { + clientOpts.SetServerAPIOptions(options.ServerAPI(driver.TestServerAPIVersion)) + } + + // Setup command monitor + var customMonitor = clientOpts.Monitor + clientOpts.SetMonitor(&event.CommandMonitor{ + Started: func(ctx context.Context, cse *event.CommandStartedEvent) { + if customMonitor != nil && customMonitor.Started != nil { + customMonitor.Started(ctx, cse) + } + t.monitorLock.Lock() + defer t.monitorLock.Unlock() + t.started = append(t.started, cse) + }, + Succeeded: func(ctx context.Context, cse *event.CommandSucceededEvent) { + if customMonitor != nil && customMonitor.Succeeded != nil { + customMonitor.Succeeded(ctx, cse) + } + t.monitorLock.Lock() + defer t.monitorLock.Unlock() + t.succeeded = append(t.succeeded, cse) + }, + Failed: func(ctx context.Context, cfe *event.CommandFailedEvent) { + if customMonitor != nil && customMonitor.Failed != nil { + customMonitor.Failed(ctx, cfe) + } + t.monitorLock.Lock() + defer t.monitorLock.Unlock() + t.failed = append(t.failed, cfe) + }, + }) + // only specify connection pool monitor if no deployment is given + if clientOpts.Deployment == nil { + previousPoolMonitor := clientOpts.PoolMonitor + + clientOpts.SetPoolMonitor(&event.PoolMonitor{ + Event: func(evt *event.PoolEvent) { + if previousPoolMonitor != nil { + previousPoolMonitor.Event(evt) + } + + switch evt.Type { + case event.ConnectionCheckedOut: + atomic.AddInt64(&t.connsCheckedOut, 1) + case event.ConnectionCheckedIn: + atomic.AddInt64(&t.connsCheckedOut, -1) + } + }, + }) + } + + var err error + switch t.clientType { + case Pinned: + // pin to first mongos + pinnedHostList := []string{testContext.connString.Hosts[0]} + uriOpts := options.Client().ApplyURI(testContext.connString.Original).SetHosts(pinnedHostList) + t.Client, err = mongo.Connect(uriOpts, clientOpts) + case Mock: + // clear pool monitor to avoid configuration error + + clientOpts.PoolMonitor = nil + + t.mockDeployment = drivertest.NewMockDeployment() + clientOpts.Deployment = t.mockDeployment + + t.Client, err = mongo.Connect(clientOpts) + case Proxy: + t.proxyDialer = newProxyDialer() + clientOpts.SetDialer(t.proxyDialer) + + // After setting the Dialer, fall-through to the Default case to apply the correct URI + fallthrough + case Default: + // Use a different set of options to specify the URI because clientOpts may already have a URI or host seedlist + // specified. + var uriOpts *options.ClientOptions + if clientOpts.Deployment == nil { + // Only specify URI if the deployment is not set to avoid setting topology/server options along with the + // deployment. + uriOpts = options.Client().ApplyURI(testContext.connString.Original) + } + + t.Client, err = mongo.Connect(uriOpts, clientOpts) + } + if err != nil { + t.Fatalf("error creating client: %v", err) + } +} + +func (t *T) createTestCollection() { + t.DB = t.Client.Database(t.dbName) + t.createdColls = t.createdColls[:0] + + // Collections should not be explicitly created when testing against Atlas Data Lake because they already exist in + // the server with pre-seeded data. + createOnServer := (t.createCollection == nil || *t.createCollection) && !testContext.dataLake + t.Coll = t.CreateCollection(Collection{ + Name: t.collName, + CreateOpts: t.collCreateOpts, + Opts: t.collOpts, + }, createOnServer) +} + +// verifyVersionConstraints returns an error if the cluster's server version is not in the range [min, max]. Server +// versions will only be checked if they are non-empty. +func verifyVersionConstraints(min, max string) error { + if min != "" && CompareServerVersions(testContext.serverVersion, min) < 0 { + return fmt.Errorf("server version %q is lower than min required version %q", testContext.serverVersion, min) + } + if max != "" && CompareServerVersions(testContext.serverVersion, max) > 0 { + return fmt.Errorf("server version %q is higher than max version %q", testContext.serverVersion, max) + } + return nil +} + +// verifyTopologyConstraints returns an error if the cluster's topology kind does not match one of the provided +// kinds. If the topologies slice is empty, nil is returned without any additional checks. +func verifyTopologyConstraints(topologies []TopologyKind) error { + if len(topologies) == 0 { + return nil + } + + for _, topo := range topologies { + // For ShardedReplicaSet, we won't get an exact match because testContext.topoKind will be Sharded so we do an + // additional comparison with the testContext.shardedReplicaSet field. + if topo == testContext.topoKind || (topo == ShardedReplicaSet && testContext.shardedReplicaSet) { + return nil + } + } + return fmt.Errorf("topology kind %q does not match any of the required kinds %q", testContext.topoKind, topologies) +} + +func verifyServerParametersConstraints(serverParameters map[string]bson.RawValue) error { + for param, expected := range serverParameters { + actual, err := testContext.serverParameters.LookupErr(param) + if err != nil { + return fmt.Errorf("server does not support parameter %q", param) + } + if !expected.Equal(actual) { + return fmt.Errorf("mismatched values for server parameter %q; expected %s, got %s", param, expected, actual) + } + } + return nil +} + +func verifyAuthConstraint(expected *bool) error { + if expected != nil && *expected != testContext.authEnabled { + return fmt.Errorf("test requires auth value: %v, cluster auth value: %v", *expected, testContext.authEnabled) + } + return nil +} + +func verifyServerlessConstraint(expected string) error { + switch expected { + case "require": + if !testContext.serverless { + return fmt.Errorf("test requires serverless") + } + case "forbid": + if testContext.serverless { + return fmt.Errorf("test forbids serverless") + } + case "allow", "": + default: + return fmt.Errorf("invalid value for serverless: %s", expected) + } + return nil +} + +// verifyRunOnBlockConstraint returns an error if the current environment does not match the provided RunOnBlock. +func verifyRunOnBlockConstraint(rob RunOnBlock) error { + if err := verifyVersionConstraints(rob.MinServerVersion, rob.MaxServerVersion); err != nil { + return err + } + if err := verifyTopologyConstraints(rob.Topology); err != nil { + return err + } + + // Tests in the unified test format have runOn.auth to indicate whether the + // test should be run against an auth-enabled configuration. SDAM integration + // spec tests have runOn.authEnabled to indicate the same thing. Use whichever + // is set for verifyAuthConstraint(). + auth := rob.Auth + if rob.AuthEnabled != nil { + if auth != nil { + return fmt.Errorf("runOnBlock cannot specify both auth and authEnabled") + } + auth = rob.AuthEnabled + } + if err := verifyAuthConstraint(auth); err != nil { + return err + } + + if err := verifyServerlessConstraint(rob.Serverless); err != nil { + return err + } + if err := verifyServerParametersConstraints(rob.ServerParameters); err != nil { + return err + } + + if rob.CSFLE != nil { + if *rob.CSFLE && !IsCSFLEEnabled() { + return fmt.Errorf("runOnBlock requires CSFLE to be enabled. Build with the cse tag to enable") + } else if !*rob.CSFLE && IsCSFLEEnabled() { + return fmt.Errorf("runOnBlock requires CSFLE to be disabled. Build without the cse tag to disable") + } + if *rob.CSFLE { + if err := verifyVersionConstraints("4.2", ""); err != nil { + return err + } + } + } + return nil +} + +// verifyConstraints returns an error if the current environment does not match the constraints specified for the test. +func (t *T) verifyConstraints() error { + // Check constraints not specified as runOn blocks + if err := verifyVersionConstraints(t.minServerVersion, t.maxServerVersion); err != nil { + return err + } + if err := verifyTopologyConstraints(t.validTopologies); err != nil { + return err + } + if err := verifyAuthConstraint(t.auth); err != nil { + return err + } + if t.ssl != nil && *t.ssl != testContext.sslEnabled { + return fmt.Errorf("test requires ssl value: %v, cluster ssl value: %v", *t.ssl, testContext.sslEnabled) + } + if t.enterprise != nil && *t.enterprise != testContext.enterpriseServer { + return fmt.Errorf("test requires enterprise value: %v, cluster enterprise value: %v", *t.enterprise, + testContext.enterpriseServer) + } + if t.dataLake != nil && *t.dataLake != testContext.dataLake { + return fmt.Errorf("test requires cluster to be data lake: %v, cluster is data lake: %v", *t.dataLake, + testContext.dataLake) + } + if t.requireAPIVersion != nil && *t.requireAPIVersion != testContext.requireAPIVersion { + return fmt.Errorf("test requires RequireAPIVersion value: %v, local RequireAPIVersion value: %v", *t.requireAPIVersion, + testContext.requireAPIVersion) + } + + // Check runOn blocks. The test can be executed if there are no blocks or at least block matches the current test + // setup. + if len(t.runOn) == 0 { + return nil + } + + // Stop once we find a RunOnBlock that matches the current environment. Record all errors as we go because if we + // don't find any matching blocks, we want to report the comparison errors for each block. + runOnErrors := make([]error, 0, len(t.runOn)) + for _, runOn := range t.runOn { + err := verifyRunOnBlockConstraint(runOn) + if err == nil { + return nil + } + + runOnErrors = append(runOnErrors, err) + } + return fmt.Errorf("no matching RunOnBlock; comparison errors: %v", runOnErrors) +} + +func (t *T) interfaceToInt32(i interface{}) (int32, error) { + switch conv := i.(type) { + case int: + return int32(conv), nil + case int32: + return conv, nil + case int64: + return int32(conv), nil + case float64: + return int32(conv), nil + } + + return 0, fmt.Errorf("type %T cannot be converted to int32", i) +} diff --git a/drivers/mongov2/internal/mtest/options.go b/drivers/mongov2/internal/mtest/options.go new file mode 100644 index 0000000..aff188b --- /dev/null +++ b/drivers/mongov2/internal/mtest/options.go @@ -0,0 +1,283 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mtest + +import ( + "errors" + "fmt" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +// TopologyKind describes the topology that a test is run on. +type TopologyKind string + +// These constants specify valid values for TopologyKind +const ( + ReplicaSet TopologyKind = "replicaset" + Sharded TopologyKind = "sharded" + Single TopologyKind = "single" + LoadBalanced TopologyKind = "load-balanced" + // ShardedReplicaSet is a special case of sharded that requires each shard to be a replica set rather than a + // standalone server. + ShardedReplicaSet TopologyKind = "sharded-replicaset" +) + +// ClientType specifies the type of Client that should be created for a test. +type ClientType int + +// These constants specify valid values for ClientType +const ( + // Default specifies a client to the connection string in the MONGODB_URI env variable with command monitoring + // enabled. + Default ClientType = iota + // Pinned specifies a client that is pinned to a single mongos in a sharded cluster. + Pinned + // Mock specifies a client that communicates with a mock deployment. + Mock + // Proxy specifies a client that proxies messages to the server and also stores parsed copies. The proxied + // messages can be retrieved via T.GetProxiedMessages or T.GetRawProxiedMessages. + Proxy +) + +var ( + falseBool = false +) + +// RunOnBlock describes a constraint for a test. +type RunOnBlock struct { + MinServerVersion string `bson:"minServerVersion"` + MaxServerVersion string `bson:"maxServerVersion"` + Topology []TopologyKind `bson:"topology"` + Serverless string `bson:"serverless"` + ServerParameters map[string]bson.RawValue `bson:"serverParameters"` + Auth *bool `bson:"auth"` + AuthEnabled *bool `bson:"authEnabled"` + CSFLE *bool `bson:"csfle"` +} + +// UnmarshalBSON implements custom BSON unmarshalling behavior for RunOnBlock because some test formats use the +// "topology" key while the unified test format uses "topologies". +func (r *RunOnBlock) UnmarshalBSON(data []byte) error { + var temp struct { + MinServerVersion string `bson:"minServerVersion"` + MaxServerVersion string `bson:"maxServerVersion"` + Topology []TopologyKind `bson:"topology"` + Topologies []TopologyKind `bson:"topologies"` + Serverless string `bson:"serverless"` + ServerParameters map[string]bson.RawValue `bson:"serverParameters"` + Auth *bool `bson:"auth"` + AuthEnabled *bool `bson:"authEnabled"` + CSFLE *bool `bson:"csfle"` + Extra map[string]interface{} `bson:",inline"` + } + if err := bson.Unmarshal(data, &temp); err != nil { + return fmt.Errorf("error unmarshalling to temporary RunOnBlock object: %w", err) + } + if len(temp.Extra) > 0 { + return fmt.Errorf("unrecognized fields for RunOnBlock: %v", temp.Extra) + } + + r.MinServerVersion = temp.MinServerVersion + r.MaxServerVersion = temp.MaxServerVersion + r.Serverless = temp.Serverless + r.ServerParameters = temp.ServerParameters + r.Auth = temp.Auth + r.AuthEnabled = temp.AuthEnabled + r.CSFLE = temp.CSFLE + + if temp.Topology != nil { + r.Topology = temp.Topology + } + if temp.Topologies != nil { + if r.Topology != nil { + return errors.New("both 'topology' and 'topologies' keys cannot be specified for a RunOnBlock") + } + + r.Topology = temp.Topologies + } + return nil +} + +// optionFunc is a function type that configures a T instance. +type optionFunc func(*T) + +// Options is the type used to configure a new T instance. +type Options struct { + optFuncs []optionFunc +} + +// NewOptions creates an empty Options instance. +func NewOptions() *Options { + return &Options{} +} + +// CollectionCreateOptions sets the options to pass to Database.CreateCollection() when creating a collection for a test. +func (op *Options) CollectionCreateOptions(opts *options.CreateCollectionOptionsBuilder) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.collCreateOpts = opts + }) + return op +} + +// CollectionOptions sets the options to use when creating a collection for a test. +func (op *Options) CollectionOptions(opts *options.CollectionOptionsBuilder) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.collOpts = opts + }) + return op +} + +// ClientOptions sets the options to use when creating a client for a test. +func (op *Options) ClientOptions(opts *options.ClientOptions) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.clientOpts = opts + }) + return op +} + +// CreateClient specifies whether or not a client should be created for a test. This should be set to false when running +// a test that only runs other tests. +func (op *Options) CreateClient(create bool) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.createClient = &create + }) + return op +} + +// CreateCollection specifies whether or not a collection should be created for a test. The default value is true. +func (op *Options) CreateCollection(create bool) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.createCollection = &create + }) + return op +} + +// ShareClient specifies whether or not a test should pass its client down to sub-tests. This should be set when calling +// New() if the inheriting behavior is desired. This option must not be used if the test accesses command monitoring +// events. +func (op *Options) ShareClient(share bool) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.shareClient = &share + }) + return op +} + +// CollectionName specifies the name for the collection for the test. +func (op *Options) CollectionName(collName string) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.collName = collName + }) + return op +} + +// DatabaseName specifies the name of the database for the test. +func (op *Options) DatabaseName(dbName string) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.dbName = dbName + }) + return op +} + +// ClientType specifies the type of client that should be created for a test. This option will be propagated to all +// sub-tests. If the provided ClientType is Proxy, the SSL(false) option will be also be added because the internal +// proxy dialer and connection types do not support SSL. +func (op *Options) ClientType(ct ClientType) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.clientType = ct + + if ct == Proxy { + t.ssl = &falseBool + } + }) + return op +} + +// MockResponses specifies the responses returned by a mock deployment. This should only be used if the current test +// is being run with MockDeployment(true). Responses can also be added after a sub-test has already been created. +func (op *Options) MockResponses(responses ...bson.D) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.mockResponses = responses + }) + return op +} + +// RunOn specifies run-on blocks used to determine if a test should run. If a test's environment meets at least one of the +// given constraints, it will be run. Otherwise, it will be skipped. +func (op *Options) RunOn(blocks ...RunOnBlock) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.runOn = append(t.runOn, blocks...) + }) + return op +} + +// MinServerVersion specifies the minimum server version for the test. +func (op *Options) MinServerVersion(version string) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.minServerVersion = version + }) + return op +} + +// MaxServerVersion specifies the maximum server version for the test. +func (op *Options) MaxServerVersion(version string) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.maxServerVersion = version + }) + return op +} + +// Topologies specifies a list of topologies that the test can run on. +func (op *Options) Topologies(topos ...TopologyKind) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.validTopologies = topos + }) + return op +} + +// Auth specifies whether or not auth should be enabled for this test to run. By default, a test will run regardless +// of whether or not auth is enabled. +func (op *Options) Auth(auth bool) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.auth = &auth + }) + return op +} + +// SSL specifies whether or not SSL should be enabled for this test to run. By default, a test will run regardless +// of whether or not SSL is enabled. +func (op *Options) SSL(ssl bool) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.ssl = &ssl + }) + return op +} + +// Enterprise specifies whether or not this test should only be run on enterprise server variants. Defaults to false. +func (op *Options) Enterprise(ent bool) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.enterprise = &ent + }) + return op +} + +// AtlasDataLake specifies whether this test should only be run against Atlas Data Lake servers. Defaults to false. +func (op *Options) AtlasDataLake(adl bool) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.dataLake = &adl + }) + return op +} + +// RequireAPIVersion specifies whether this test should only be run when REQUIRE_API_VERSION is true. Defaults to false. +func (op *Options) RequireAPIVersion(rav bool) *Options { + op.optFuncs = append(op.optFuncs, func(t *T) { + t.requireAPIVersion = &rav + }) + return op +} diff --git a/drivers/mongov2/internal/mtest/proxy_dialer.go b/drivers/mongov2/internal/mtest/proxy_dialer.go new file mode 100644 index 0000000..7f17dbb --- /dev/null +++ b/drivers/mongov2/internal/mtest/proxy_dialer.go @@ -0,0 +1,186 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mtest + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" +) + +// ProxyMessage represents a sent/received pair of parsed wire messages. +type ProxyMessage struct { + ServerAddress string + CommandName string + Sent *SentMessage + Received *ReceivedMessage +} + +// proxyDialer is a ContextDialer implementation that wraps a net.Dialer and records the messages sent and received +// using connections created through it. +type proxyDialer struct { + *net.Dialer + sync.Mutex + + messages []*ProxyMessage + // sentMap temporarily stores the message sent to the server using the requestID so it can map requests to their + // responses. + sentMap sync.Map + // addressTranslations maps dialed addresses to the remote addresses reported by the created connections if they + // differ. This can happen if a connection is dialed to a host name, in which case the reported remote address will + // be the resolved IP address. + addressTranslations sync.Map +} + +var _ options.ContextDialer = (*proxyDialer)(nil) + +func newProxyDialer() *proxyDialer { + return &proxyDialer{ + Dialer: &net.Dialer{Timeout: 30 * time.Second}, + } +} + +func newProxyErrorWithWireMsg(wm []byte, err error) error { + return fmt.Errorf("proxy error for wiremessage %v: %w", wm, err) +} + +// DialContext creates a new proxyConnection. +func (p *proxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + netConn, err := p.Dialer.DialContext(ctx, network, address) + if err != nil { + return netConn, err + } + + // If the connection's remote address does not match the dialed address, store it in the translations map for + // future look-up. Use the remote address as they key because that's what we'll have access to in the connection's + // Read/Write functions. + if remoteAddress := netConn.RemoteAddr().String(); remoteAddress != address { + p.addressTranslations.Store(remoteAddress, address) + } + + proxy := &proxyConn{ + Conn: netConn, + dialer: p, + } + return proxy, nil +} + +func (p *proxyDialer) storeSentMessage(wm []byte) error { + p.Lock() + defer p.Unlock() + + // Create a copy of the wire message so it can be parsed/stored and will not be affected if the wm slice is + // changed by the driver. + wmCopy := copyBytes(wm) + parsed, err := parseSentMessage(wmCopy) + if err != nil { + return err + } + p.sentMap.Store(parsed.RequestID, parsed) + return nil +} + +func (p *proxyDialer) storeReceivedMessage(wm []byte, addr string) error { + p.Lock() + defer p.Unlock() + + serverAddress := addr + if translated, ok := p.addressTranslations.Load(addr); ok { + serverAddress = translated.(string) + } + + // Create a copy of the wire message so it can be parsed/stored and will not be affected if the wm slice is + // changed by the driver. Parse the incoming message and get the corresponding outgoing message. + wmCopy := copyBytes(wm) + parsed, err := parseReceivedMessage(wmCopy) + if err != nil { + return err + } + mapValue, ok := p.sentMap.Load(parsed.ResponseTo) + if !ok { + return errors.New("no sent message found") + } + sent := mapValue.(*SentMessage) + p.sentMap.Delete(parsed.ResponseTo) + + // Store the parsed message pair. + msgPair := &ProxyMessage{ + // The command name is always the first key in the command document. + CommandName: sent.Command.Index(0).Key(), + ServerAddress: serverAddress, + Sent: sent, + Received: parsed, + } + p.messages = append(p.messages, msgPair) + return nil +} + +// Messages returns a slice of proxied messages. This slice is a copy of the messages proxied so far and will not be +// updated for messages proxied after this call. +func (p *proxyDialer) Messages() []*ProxyMessage { + p.Lock() + defer p.Unlock() + + copiedMessages := make([]*ProxyMessage, len(p.messages)) + copy(copiedMessages, p.messages) + return copiedMessages +} + +// proxyConn is a net.Conn that wraps a network connection. All messages sent/received through a proxyConn are stored +// in the associated proxyDialer and are forwarded over the wrapped connection. Errors encountered when parsing and +// storing wire messages are wrapped to add context, while errors returned from the underlying network connection are +// forwarded without wrapping. +type proxyConn struct { + net.Conn + dialer *proxyDialer +} + +// Write stores the given message in the proxyDialer associated with this connection and forwards the message to the +// server. +func (pc *proxyConn) Write(wm []byte) (n int, err error) { + if err := pc.dialer.storeSentMessage(wm); err != nil { + wrapped := fmt.Errorf("error storing sent message: %w", err) + return 0, newProxyErrorWithWireMsg(wm, wrapped) + } + + return pc.Conn.Write(wm) +} + +// Read reads the message from the server into the given buffer and stores the read message in the proxyDialer +// associated with this connection. +func (pc *proxyConn) Read(buffer []byte) (int, error) { + n, err := pc.Conn.Read(buffer) + if err != nil { + return n, err + } + + // The driver reads wire messages in two phases: a four-byte read to get the length of the incoming wire message + // and a (length-4) byte read to get the message itself. There's nothing to be stored during the initial four-byte + // read because we can calculate the length from the rest of the message. + if len(buffer) == 4 { + return 4, nil + } + + // The buffer contains the entire wire message except for the length bytes. Re-create the full message by appending + // buffer to the end of a four-byte slice and using UpdateLength to set the length bytes. + idx, wm := bsoncore.ReserveLength(nil) + wm = append(wm, buffer...) + wm = bsoncore.UpdateLength(wm, idx, int32(len(wm[idx:]))) + + if err := pc.dialer.storeReceivedMessage(wm, pc.RemoteAddr().String()); err != nil { + wrapped := fmt.Errorf("error storing received message: %w", err) + return 0, newProxyErrorWithWireMsg(wm, wrapped) + } + + return n, nil +} diff --git a/drivers/mongov2/internal/mtest/received_message.go b/drivers/mongov2/internal/mtest/received_message.go new file mode 100644 index 0000000..91807a6 --- /dev/null +++ b/drivers/mongov2/internal/mtest/received_message.go @@ -0,0 +1,124 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mtest + +import ( + "errors" + "fmt" + + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage" +) + +// ReceivedMessage represents a message received from the server. +type ReceivedMessage struct { + ResponseTo int32 + RawMessage wiremessage.WireMessage + Response bsoncore.Document +} + +type receivedMsgParseFn func([]byte) (*ReceivedMessage, error) + +func getReceivedMessageParser(opcode wiremessage.OpCode) (receivedMsgParseFn, bool) { + switch opcode { + case wiremessage.OpReply: + return parseOpReply, true + case wiremessage.OpMsg: + return parseReceivedOpMsg, true + case wiremessage.OpCompressed: + return parseReceivedOpCompressed, true + default: + return nil, false + } +} + +func parseReceivedMessage(wm []byte) (*ReceivedMessage, error) { + // Re-assign the wire message to "remaining" so "wm" continues to point to the entire message after parsing. + _, _, responseTo, opcode, remaining, ok := wiremessage.ReadHeader(wm) + if !ok { + return nil, errors.New("failed to read wiremessage header") + } + + parseFn, ok := getReceivedMessageParser(opcode) + if !ok { + return nil, fmt.Errorf("unknown opcode: %s", opcode) + } + received, err := parseFn(remaining) + if err != nil { + return nil, fmt.Errorf("error parsing wiremessage with opcode %s: %w", opcode, err) + } + + received.ResponseTo = responseTo + received.RawMessage = wm + return received, nil +} + +func parseOpReply(wm []byte) (*ReceivedMessage, error) { + var ok bool + + if _, wm, ok = wiremessage.ReadReplyFlags(wm); !ok { + return nil, errors.New("failed to read reply flags") + } + if _, wm, ok = wiremessage.ReadReplyCursorID(wm); !ok { + return nil, errors.New("failed to read cursor ID") + } + if _, wm, ok = wiremessage.ReadReplyStartingFrom(wm); !ok { + return nil, errors.New("failed to read starting from") + } + if _, wm, ok = wiremessage.ReadReplyNumberReturned(wm); !ok { + return nil, errors.New("failed to read number returned") + } + + var replyDocuments []bsoncore.Document + replyDocuments, wm, ok = wiremessage.ReadReplyDocuments(wm) + if !ok { + return nil, errors.New("failed to read reply documents") + } + if len(replyDocuments) == 0 { + return nil, errors.New("no documents in response") + } + + rm := &ReceivedMessage{ + Response: replyDocuments[0], + } + return rm, nil +} + +func parseReceivedOpMsg(wm []byte) (*ReceivedMessage, error) { + var ok bool + var err error + + if _, wm, ok = wiremessage.ReadMsgFlags(wm); !ok { + return nil, errors.New("failed to read flags") + } + + if wm, err = assertMsgSectionType(wm, wiremessage.SingleDocument); err != nil { + return nil, fmt.Errorf("error verifying section type for response document: %w", err) + } + + response, wm, ok := wiremessage.ReadMsgSectionSingleDocument(wm) + if !ok { + return nil, errors.New("failed to read response document") + } + rm := &ReceivedMessage{ + Response: response, + } + return rm, nil +} + +func parseReceivedOpCompressed(wm []byte) (*ReceivedMessage, error) { + originalOpcode, wm, err := parseOpCompressed(wm) + if err != nil { + return nil, err + } + + parser, ok := getReceivedMessageParser(originalOpcode) + if !ok { + return nil, fmt.Errorf("unknown original opcode %v", originalOpcode) + } + return parser(wm) +} diff --git a/drivers/mongov2/internal/mtest/sent_message.go b/drivers/mongov2/internal/mtest/sent_message.go new file mode 100644 index 0000000..5be6cee --- /dev/null +++ b/drivers/mongov2/internal/mtest/sent_message.go @@ -0,0 +1,195 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mtest + +import ( + "errors" + "fmt" + + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage" +) + +// SentMessage represents a message sent by the driver to the server. +type SentMessage struct { + RequestID int32 + RawMessage wiremessage.WireMessage + Command bsoncore.Document + OpCode wiremessage.OpCode + + // The $readPreference document. This is separated into its own field even though it's included in the larger + // command document in both OP_QUERY and OP_MSG because OP_QUERY separates the command into a $query sub-document + // if there is a read preference. To unify OP_QUERY and OP_MSG, we pull this out into a separate field and set + // the Command field to the $query sub-document. + ReadPreference bsoncore.Document + + // The documents sent for an insert, update, or delete command. This is separated into its own field because it's + // sent as part of the command document in OP_QUERY and as a document sequence outside the command document in + // OP_MSG. + Batch *bsoncore.Iterator +} + +type sentMsgParseFn func([]byte) (*SentMessage, error) + +func getSentMessageParser(opcode wiremessage.OpCode) (sentMsgParseFn, bool) { + switch opcode { + case wiremessage.OpQuery: + return parseOpQuery, true + case wiremessage.OpMsg: + return parseSentOpMsg, true + case wiremessage.OpCompressed: + return parseSentOpCompressed, true + default: + return nil, false + } +} + +func parseOpQuery(wm []byte) (*SentMessage, error) { + var ok bool + + if _, wm, ok = wiremessage.ReadQueryFlags(wm); !ok { + return nil, errors.New("failed to read query flags") + } + if _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm); !ok { + return nil, errors.New("failed to read full collection name") + } + if _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm); !ok { + return nil, errors.New("failed to read number to skip") + } + if _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm); !ok { + return nil, errors.New("failed to read number to return") + } + + query, wm, ok := wiremessage.ReadQueryQuery(wm) + if !ok { + return nil, errors.New("failed to read query") + } + + // If there is no read preference document, the command document is query. + // Otherwise, query is in the format {$query: , $readPreference: }. + commandDoc := query + var rpDoc bsoncore.Document + + dollarQueryVal, err := query.LookupErr("$query") + if err == nil { + commandDoc = dollarQueryVal.Document() + + rpVal, err := query.LookupErr("$readPreference") + if err != nil { + return nil, fmt.Errorf("query %s contains $query but not $readPreference fields", query) + } + rpDoc = rpVal.Document() + } + + // For OP_QUERY, inserts, updates, and deletes are sent as a BSON array of documents inside the main command + // document. Pull these sequences out into an ArrayStyle DocumentSequence. + var batch *bsoncore.Iterator + cmdElems, _ := commandDoc.Elements() + for _, elem := range cmdElems { + switch elem.Key() { + case "documents", "updates", "deletes": + batch = &bsoncore.Iterator{ + List: elem.Value().Array(), + } + } + if batch != nil { + // There can only be one of these arrays in a well-formed command, so we exit the loop once one is found. + break + } + } + + sm := &SentMessage{ + Command: commandDoc, + ReadPreference: rpDoc, + Batch: batch, + } + return sm, nil +} + +func parseSentMessage(wm []byte) (*SentMessage, error) { + // Re-assign the wire message to "remaining" so "wm" continues to point to the entire message after parsing. + _, requestID, _, opcode, remaining, ok := wiremessage.ReadHeader(wm) + if !ok { + return nil, errors.New("failed to read wiremessage header") + } + + parseFn, ok := getSentMessageParser(opcode) + if !ok { + return nil, fmt.Errorf("unknown opcode: %v", opcode) + } + sent, err := parseFn(remaining) + if err != nil { + return nil, fmt.Errorf("error parsing wiremessage with opcode %s: %w", opcode, err) + } + + sent.RequestID = requestID + sent.RawMessage = wm + sent.OpCode = opcode + return sent, nil +} + +func parseSentOpMsg(wm []byte) (*SentMessage, error) { + var ok bool + var err error + + if _, wm, ok = wiremessage.ReadMsgFlags(wm); !ok { + return nil, errors.New("failed to read flags") + } + + if wm, err = assertMsgSectionType(wm, wiremessage.SingleDocument); err != nil { + return nil, fmt.Errorf("error verifying section type for command document: %w", err) + } + + var commandDoc bsoncore.Document + commandDoc, wm, ok = wiremessage.ReadMsgSectionSingleDocument(wm) + if !ok { + return nil, errors.New("failed to read command document") + } + + var rpDoc bsoncore.Document + if rpVal, err := commandDoc.LookupErr("$readPreference"); err == nil { + rpDoc = rpVal.Document() + } + + var batch *bsoncore.Iterator + if len(wm) != 0 { + // If there are bytes remaining in the wire message, they must correspond to a DocumentSequence section. + if wm, err = assertMsgSectionType(wm, wiremessage.DocumentSequence); err != nil { + return nil, fmt.Errorf("error verifying section type for document sequence: %w", err) + } + + var data []byte + _, data, wm, ok = wiremessage.ReadMsgSectionRawDocumentSequence(wm) + if !ok { + return nil, errors.New("failed to read document sequence") + } + + batch = &bsoncore.Iterator{ + List: data, + } + } + + sm := &SentMessage{ + Command: commandDoc, + ReadPreference: rpDoc, + Batch: batch, + } + return sm, nil +} + +func parseSentOpCompressed(wm []byte) (*SentMessage, error) { + originalOpcode, wm, err := parseOpCompressed(wm) + if err != nil { + return nil, err + } + + parser, ok := getSentMessageParser(originalOpcode) + if !ok { + return nil, fmt.Errorf("unknown original opcode %v", originalOpcode) + } + return parser(wm) +} diff --git a/drivers/mongov2/internal/mtest/setup.go b/drivers/mongov2/internal/mtest/setup.go new file mode 100644 index 0000000..f2a0d30 --- /dev/null +++ b/drivers/mongov2/internal/mtest/setup.go @@ -0,0 +1,376 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mtest + +import ( + "context" + "errors" + "fmt" + "math" + "os" + "strconv" + "strings" + "time" + + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/integtest" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/mongo/readpref" + "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/connstring" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology" +) + +const ( + // TestDb specifies the name of default test database. + TestDb = "test" +) + +// testContext holds the global context for the integration tests. The testContext members should only be initialized +// once during the global setup in TestMain. These variables should only be accessed indirectly through MongoTest +// instances. +var testContext struct { + connString *connstring.ConnString + topo *topology.Topology + topoKind TopologyKind + // shardedReplicaSet will be true if we're connected to a sharded cluster and each shard is backed by a replica set. + // We track this as a separate boolean rather than setting topoKind to ShardedReplicaSet because a general + // "Sharded" constraint in a test should match both Sharded and ShardedReplicaSet. + shardedReplicaSet bool + client *mongo.Client // client used for setup and teardown + serverVersion string + authEnabled bool + sslEnabled bool + enterpriseServer bool + dataLake bool + requireAPIVersion bool + serverParameters bson.Raw + singleMongosLoadBalancerURI string + multiMongosLoadBalancerURI string + serverless bool +} + +func setupClient(opts *options.ClientOptions) (*mongo.Client, error) { + wcMajority := writeconcern.Majority() + // set ServerAPIOptions to latest version if required + if opts.ServerAPIOptions == nil && testContext.requireAPIVersion { + opts.SetServerAPIOptions(options.ServerAPI(driver.TestServerAPIVersion)) + } + // for sharded clusters, pin to one host. Due to how the cache is implemented on 4.0 and 4.2, behavior + // can be inconsistent when multiple mongoses are used + return mongo.Connect(opts.SetWriteConcern(wcMajority).SetHosts(opts.Hosts[:1])) +} + +// Setup initializes the current testing context. +// This function must only be called one time and must be called before any tests run. +func Setup(setupOpts ...*SetupOptions) error { + opts := NewSetupOptions() + for _, opt := range setupOpts { + if opt == nil { + continue + } + if opt.URI != nil { + opts.URI = opt.URI + } + } + + var uri string + var err error + + switch { + case opts.URI != nil: + uri = *opts.URI + default: + var err error + uri, err = integtest.MongoDBURI() + if err != nil { + return fmt.Errorf("error getting uri: %w", err) + } + } + + testContext.connString, err = connstring.ParseAndValidate(uri) + if err != nil { + return fmt.Errorf("error parsing and validating connstring: %w", err) + } + + testContext.dataLake = os.Getenv("ATLAS_DATA_LAKE_INTEGRATION_TEST") == "true" + testContext.requireAPIVersion = os.Getenv("REQUIRE_API_VERSION") == "true" + + clientOpts := options.Client().ApplyURI(uri) + integtest.AddTestServerAPIVersion(clientOpts) + + cfg, err := topology.NewConfig(clientOpts, nil) + if err != nil { + return fmt.Errorf("error constructing topology config: %w", err) + } + + testContext.topo, err = topology.New(cfg) + if err != nil { + return fmt.Errorf("error creating topology: %w", err) + } + if err = testContext.topo.Connect(); err != nil { + return fmt.Errorf("error connecting topology: %w", err) + } + + testContext.client, err = setupClient(options.Client().ApplyURI(uri)) + if err != nil { + return fmt.Errorf("error connecting test client: %w", err) + } + + pingCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := testContext.client.Ping(pingCtx, readpref.Primary()); err != nil { + return fmt.Errorf("ping error: %w; make sure the deployment is running on URI %v", err, + testContext.connString.Original) + } + + if testContext.serverVersion, err = getServerVersion(); err != nil { + return fmt.Errorf("error getting server version: %w", err) + } + + switch testContext.topo.Kind() { + case description.TopologyKindSingle: + testContext.topoKind = Single + case description.TopologyKindReplicaSet, description.TopologyKindReplicaSetWithPrimary, description.TopologyKindReplicaSetNoPrimary: + testContext.topoKind = ReplicaSet + case description.TopologyKindSharded: + testContext.topoKind = Sharded + case description.TopologyKindLoadBalanced: + testContext.topoKind = LoadBalanced + default: + return fmt.Errorf("could not detect topology kind; current topology: %s", testContext.topo.String()) + } + + // If we're connected to a sharded cluster, determine if the cluster is backed by replica sets. + if testContext.topoKind == Sharded { + // Run a find against config.shards and get each document in the collection. + cursor, err := testContext.client.Database("config").Collection("shards").Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("error running find against config.shards: %w", err) + } + defer cursor.Close(context.Background()) + + var shards []struct { + Host string `bson:"host"` + } + if err := cursor.All(context.Background(), &shards); err != nil { + return fmt.Errorf("error getting results find against config.shards: %w", err) + } + + // Each document's host field will contain a single hostname if the shard is a standalone. If it's a replica + // set, the host field will be in the format "replicaSetName/host1,host2,...". Therefore, we can determine that + // the shard is a standalone if the "/" character isn't present. + var foundStandalone bool + for _, shard := range shards { + if !strings.Contains(shard.Host, "/") { + foundStandalone = true + break + } + } + if !foundStandalone { + testContext.shardedReplicaSet = true + } + } + + // For non-serverless, load balanced clusters, retrieve the required LB URIs and add additional information (e.g. TLS options) to + // them if necessary. + testContext.serverless = os.Getenv("SERVERLESS") == "serverless" + if !testContext.serverless && testContext.topoKind == LoadBalanced { + singleMongosURI := os.Getenv("SINGLE_MONGOS_LB_URI") + if singleMongosURI == "" { + return errors.New("SINGLE_MONGOS_LB_URI must be set when running against load balanced clusters") + } + testContext.singleMongosLoadBalancerURI, err = addNecessaryParamsToURI(singleMongosURI) + if err != nil { + return fmt.Errorf("error getting single mongos load balancer uri: %w", err) + } + + multiMongosURI := os.Getenv("MULTI_MONGOS_LB_URI") + if multiMongosURI == "" { + return errors.New("MULTI_MONGOS_LB_URI must be set when running against load balanced clusters") + } + testContext.multiMongosLoadBalancerURI, err = addNecessaryParamsToURI(multiMongosURI) + if err != nil { + return fmt.Errorf("error getting multi mongos load balancer uri: %w", err) + } + } + + testContext.authEnabled = os.Getenv("AUTH") == "auth" + testContext.sslEnabled = os.Getenv("SSL") == "ssl" + biRes, err := testContext.client.Database("admin").RunCommand(context.Background(), bson.D{{"buildInfo", 1}}).Raw() + if err != nil { + return fmt.Errorf("buildInfo error: %w", err) + } + modulesRaw, err := biRes.LookupErr("modules") + if err == nil { + // older server versions don't report "modules" field in buildInfo result + modules, _ := modulesRaw.Array().Values() + for _, module := range modules { + if module.StringValue() == "enterprise" { + testContext.enterpriseServer = true + break + } + } + } + + // Get server parameters if test is not running against ADL; ADL does not have "getParameter" command. + if !testContext.dataLake { + db := testContext.client.Database("admin") + testContext.serverParameters, err = db.RunCommand(context.Background(), bson.D{{"getParameter", "*"}}).Raw() + if err != nil { + return fmt.Errorf("error getting serverParameters: %w", err) + } + } + return nil +} + +// Teardown cleans up resources initialized by Setup. +// This function must be called once after all tests have finished running. +func Teardown() error { + // Dropping the test database causes an error against Atlas Data Lake. + if !testContext.dataLake { + if err := testContext.client.Database(TestDb).Drop(context.Background()); err != nil { + return fmt.Errorf("error dropping test database: %w", err) + } + } + if err := testContext.client.Disconnect(context.Background()); err != nil { + return fmt.Errorf("error disconnecting test client: %w", err) + } + if err := testContext.topo.Disconnect(context.Background()); err != nil { + return fmt.Errorf("error disconnecting test topology: %w", err) + } + return nil +} + +func getServerVersion() (string, error) { + var serverStatus bson.Raw + err := testContext.client.Database(TestDb).RunCommand( + context.Background(), + bson.D{{"buildInfo", 1}}, + ).Decode(&serverStatus) + if err != nil { + return "", err + } + + version, err := serverStatus.LookupErr("version") + if err != nil { + return "", errors.New("no version string in serverStatus response") + } + + return version.StringValue(), nil +} + +// addOptions appends connection string options to a URI. +func addOptions(uri string, opts ...string) string { + if !strings.ContainsRune(uri, '?') { + if uri[len(uri)-1] != '/' { + uri += "/" + } + + uri += "?" + } else { + uri += "&" + } + + for _, opt := range opts { + uri += opt + } + + return uri +} + +// addTLSConfig checks for the environmental variable indicating that the tests are being run +// on an SSL-enabled server, and if so, returns a new URI with the necessary configuration. +func addTLSConfig(uri string) string { + if os.Getenv("SSL") == "ssl" { + uri = addOptions(uri, "ssl=", "true") + } + caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE") + if len(caFile) == 0 { + return uri + } + + return addOptions(uri, "sslCertificateAuthorityFile=", caFile) +} + +// addCompressors checks for the environment variable indicating that the tests are being run with compression +// enabled. If so, it returns a new URI with the necessary configuration +func addCompressors(uri string) string { + comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR") + if len(comp) == 0 { + return uri + } + + return addOptions(uri, "compressors=", comp) +} + +func addServerlessAuthCredentials(uri string) (string, error) { + if os.Getenv("SERVERLESS") != "serverless" { + return uri, nil + } + user := os.Getenv("SERVERLESS_ATLAS_USER") + if user == "" { + return "", fmt.Errorf("serverless expects SERVERLESS_ATLAS_USER to be set") + } + password := os.Getenv("SERVERLESS_ATLAS_PASSWORD") + if password == "" { + return "", fmt.Errorf("serverless expects SERVERLESS_ATLAS_PASSWORD to be set") + } + + var scheme string + // remove the scheme + switch { + case strings.HasPrefix(uri, "mongodb+srv://"): + scheme = "mongodb+srv://" + case strings.HasPrefix(uri, "mongodb://"): + scheme = "mongodb://" + default: + return "", errors.New(`scheme must be "mongodb" or "mongodb+srv"`) + } + + uri = scheme + user + ":" + password + "@" + uri[len(scheme):] + return uri, nil +} + +func addNecessaryParamsToURI(uri string) (string, error) { + uri = addTLSConfig(uri) + uri = addCompressors(uri) + return addServerlessAuthCredentials(uri) +} + +// CompareServerVersions compares two version number strings (i.e. positive integers separated by +// periods). Comparisons are done to the lesser precision of the two versions. For example, 3.2 is +// considered equal to 3.2.11, whereas 3.2.0 is considered less than 3.2.11. +// +// Returns a positive int if version1 is greater than version2, a negative int if version1 is less +// than version2, and 0 if version1 is equal to version2. +func CompareServerVersions(v1 string, v2 string) int { + n1 := strings.Split(v1, ".") + n2 := strings.Split(v2, ".") + + for i := 0; i < int(math.Min(float64(len(n1)), float64(len(n2)))); i++ { + i1, err := strconv.Atoi(n1[i]) + if err != nil { + return 1 + } + + i2, err := strconv.Atoi(n2[i]) + if err != nil { + return -1 + } + + difference := i1 - i2 + if difference != 0 { + return difference + } + } + + return 0 +} diff --git a/drivers/mongov2/internal/mtest/setup_options.go b/drivers/mongov2/internal/mtest/setup_options.go new file mode 100644 index 0000000..76a3c27 --- /dev/null +++ b/drivers/mongov2/internal/mtest/setup_options.go @@ -0,0 +1,25 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mtest + +// SetupOptions is the type used to configure mtest setup +type SetupOptions struct { + // Specifies the URI to connect to. Defaults to URI based on the environment variables MONGODB_URI, + // MONGO_GO_DRIVER_CA_FILE, and MONGO_GO_DRIVER_COMPRESSOR + URI *string +} + +// NewSetupOptions creates an empty SetupOptions struct +func NewSetupOptions() *SetupOptions { + return &SetupOptions{} +} + +// SetURI sets the uri to connect to +func (so *SetupOptions) SetURI(uri string) *SetupOptions { + so.URI = &uri + return so +} diff --git a/drivers/mongov2/internal/mtest/wiremessage_helpers.go b/drivers/mongov2/internal/mtest/wiremessage_helpers.go new file mode 100644 index 0000000..8252c6b --- /dev/null +++ b/drivers/mongov2/internal/mtest/wiremessage_helpers.go @@ -0,0 +1,67 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mtest + +import ( + "errors" + "fmt" + + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage" +) + +func copyBytes(original []byte) []byte { + newSlice := make([]byte, len(original)) + copy(newSlice, original) + return newSlice +} + +// assertMsgSectionType asserts that the next section type in the OP_MSG wire message is equal to the provided type. +// It returns the remainder of the wire message and an error if the section type could not be read or was not equal +// to the expected type. +func assertMsgSectionType(wm []byte, expected wiremessage.SectionType) ([]byte, error) { + var actual wiremessage.SectionType + var ok bool + + actual, wm, ok = wiremessage.ReadMsgSectionType(wm) + if !ok { + return wm, errors.New("failed to read section type") + } + if expected != actual { + return wm, fmt.Errorf("unexpected section type %v; expected %v", actual, expected) + } + return wm, nil +} + +func parseOpCompressed(wm []byte) (wiremessage.OpCode, []byte, error) { + // Store the original opcode to forward to another parser later. + originalOpcode, wm, ok := wiremessage.ReadCompressedOriginalOpCode(wm) + if !ok { + return originalOpcode, nil, errors.New("failed to read original opcode") + } + + uncompressedSize, wm, ok := wiremessage.ReadCompressedUncompressedSize(wm) + if !ok { + return originalOpcode, nil, errors.New("failed to read uncompressed size") + } + + compressorID, compressedMsg, ok := wiremessage.ReadCompressedCompressorID(wm) + if !ok { + return originalOpcode, nil, errors.New("failed to read compressor ID") + } + + opts := driver.CompressionOpts{ + Compressor: compressorID, + UncompressedSize: uncompressedSize, + } + decompressed, err := driver.DecompressPayload(compressedMsg, opts) + if err != nil { + return originalOpcode, nil, fmt.Errorf("error decompressing payload: %w", err) + } + + return originalOpcode, decompressed, nil +} diff --git a/drivers/mongov2/internal/require/require.go b/drivers/mongov2/internal/require/require.go new file mode 100644 index 0000000..aa6ba69 --- /dev/null +++ b/drivers/mongov2/internal/require/require.go @@ -0,0 +1,819 @@ +// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/require/require.go + +// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved. +// Use of this source code is governed by an MIT-style license that can be found in +// the THIRD-PARTY-NOTICES file. + +package require + +import ( + time "time" + + assert "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/assert" +) + +// TestingT is an interface wrapper around *testing.T +type TestingT interface { + Errorf(format string, args ...interface{}) + FailNow() +} + +type tHelper interface { + Helper() +} + +// Contains asserts that the specified string, list(array, slice...) or map contains the +// specified substring or element. +// +// assert.Contains(t, "Hello World", "World") +// assert.Contains(t, ["Hello", "World"], "World") +// assert.Contains(t, {"Hello": "World"}, "Hello") +func Contains(t TestingT, s interface{}, contains interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Contains(t, s, contains, msgAndArgs...) { + return + } + t.FailNow() +} + +// Containsf asserts that the specified string, list(array, slice...) or map contains the +// specified substring or element. +// +// assert.Containsf(t, "Hello World", "World", "error message %s", "formatted") +// assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted") +// assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted") +func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Containsf(t, s, contains, msg, args...) { + return + } + t.FailNow() +} + +// ElementsMatch asserts that the specified listA(array, slice...) is equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should match. +// +// assert.ElementsMatch(t, [1, 3, 2, 3], [1, 3, 3, 2]) +func ElementsMatch(t TestingT, listA interface{}, listB interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.ElementsMatch(t, listA, listB, msgAndArgs...) { + return + } + t.FailNow() +} + +// ElementsMatchf asserts that the specified listA(array, slice...) is equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should match. +// +// assert.ElementsMatchf(t, [1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted") +func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.ElementsMatchf(t, listA, listB, msg, args...) { + return + } + t.FailNow() +} + +// Equal asserts that two objects are equal. +// +// assert.Equal(t, 123, 123) +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). Function equality +// cannot be determined and will always fail. +func Equal(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Equal(t, expected, actual, msgAndArgs...) { + return + } + t.FailNow() +} + +// EqualError asserts that a function returned an error (i.e. not `nil`) +// and that it is equal to the provided error. +// +// actualObj, err := SomeFunction() +// assert.EqualError(t, err, expectedErrorString) +func EqualError(t TestingT, theError error, errString string, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.EqualError(t, theError, errString, msgAndArgs...) { + return + } + t.FailNow() +} + +// EqualErrorf asserts that a function returned an error (i.e. not `nil`) +// and that it is equal to the provided error. +// +// actualObj, err := SomeFunction() +// assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted") +func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.EqualErrorf(t, theError, errString, msg, args...) { + return + } + t.FailNow() +} + +// EqualValues asserts that two objects are equal or convertible to the same types +// and equal. +// +// assert.EqualValues(t, uint32(123), int32(123)) +func EqualValues(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.EqualValues(t, expected, actual, msgAndArgs...) { + return + } + t.FailNow() +} + +// EqualValuesf asserts that two objects are equal or convertible to the same types +// and equal. +// +// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") +func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.EqualValuesf(t, expected, actual, msg, args...) { + return + } + t.FailNow() +} + +// Equalf asserts that two objects are equal. +// +// assert.Equalf(t, 123, 123, "error message %s", "formatted") +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). Function equality +// cannot be determined and will always fail. +func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Equalf(t, expected, actual, msg, args...) { + return + } + t.FailNow() +} + +// Error asserts that a function returned an error (i.e. not `nil`). +// +// actualObj, err := SomeFunction() +// if assert.Error(t, err) { +// assert.Equal(t, expectedError, err) +// } +func Error(t TestingT, err error, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Error(t, err, msgAndArgs...) { + return + } + t.FailNow() +} + +// ErrorContains asserts that a function returned an error (i.e. not `nil`) +// and that the error contains the specified substring. +// +// actualObj, err := SomeFunction() +// assert.ErrorContains(t, err, expectedErrorSubString) +func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.ErrorContains(t, theError, contains, msgAndArgs...) { + return + } + t.FailNow() +} + +// ErrorContainsf asserts that a function returned an error (i.e. not `nil`) +// and that the error contains the specified substring. +// +// actualObj, err := SomeFunction() +// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted") +func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.ErrorContainsf(t, theError, contains, msg, args...) { + return + } + t.FailNow() +} + +// Errorf asserts that a function returned an error (i.e. not `nil`). +// +// actualObj, err := SomeFunction() +// if assert.Errorf(t, err, "error message %s", "formatted") { +// assert.Equal(t, expectedErrorf, err) +// } +func Errorf(t TestingT, err error, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Errorf(t, err, msg, args...) { + return + } + t.FailNow() +} + +// Eventually asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. +// +// assert.Eventually(t, func() bool { return true; }, time.Second, 10*time.Millisecond) +func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Eventually(t, condition, waitFor, tick, msgAndArgs...) { + return + } + t.FailNow() +} + +// Eventuallyf asserts that given condition will be met in waitFor time, +// periodically checking target function each tick. +// +// assert.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Eventuallyf(t, condition, waitFor, tick, msg, args...) { + return + } + t.FailNow() +} + +// Fail reports a failure through +func Fail(t TestingT, failureMessage string, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Fail(t, failureMessage, msgAndArgs...) { + return + } + t.FailNow() +} + +// FailNow fails test +func FailNow(t TestingT, failureMessage string, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.FailNow(t, failureMessage, msgAndArgs...) { + return + } + t.FailNow() +} + +// FailNowf fails test +func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.FailNowf(t, failureMessage, msg, args...) { + return + } + t.FailNow() +} + +// Failf reports a failure through +func Failf(t TestingT, failureMessage string, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Failf(t, failureMessage, msg, args...) { + return + } + t.FailNow() +} + +// False asserts that the specified value is false. +// +// assert.False(t, myBool) +func False(t TestingT, value bool, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.False(t, value, msgAndArgs...) { + return + } + t.FailNow() +} + +// Falsef asserts that the specified value is false. +// +// assert.Falsef(t, myBool, "error message %s", "formatted") +func Falsef(t TestingT, value bool, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Falsef(t, value, msg, args...) { + return + } + t.FailNow() +} + +// Greater asserts that the first element is greater than the second +// +// assert.Greater(t, 2, 1) +// assert.Greater(t, float64(2), float64(1)) +// assert.Greater(t, "b", "a") +func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Greater(t, e1, e2, msgAndArgs...) { + return + } + t.FailNow() +} + +// GreaterOrEqual asserts that the first element is greater than or equal to the second +// +// assert.GreaterOrEqual(t, 2, 1) +// assert.GreaterOrEqual(t, 2, 2) +// assert.GreaterOrEqual(t, "b", "a") +// assert.GreaterOrEqual(t, "b", "b") +func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.GreaterOrEqual(t, e1, e2, msgAndArgs...) { + return + } + t.FailNow() +} + +// GreaterOrEqualf asserts that the first element is greater than or equal to the second +// +// assert.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted") +// assert.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted") +// assert.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted") +// assert.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted") +func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.GreaterOrEqualf(t, e1, e2, msg, args...) { + return + } + t.FailNow() +} + +// Greaterf asserts that the first element is greater than the second +// +// assert.Greaterf(t, 2, 1, "error message %s", "formatted") +// assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted") +// assert.Greaterf(t, "b", "a", "error message %s", "formatted") +func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Greaterf(t, e1, e2, msg, args...) { + return + } + t.FailNow() +} + +// InDelta asserts that the two numerals are within delta of each other. +// +// assert.InDelta(t, math.Pi, 22/7.0, 0.01) +func InDelta(t TestingT, expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.InDelta(t, expected, actual, delta, msgAndArgs...) { + return + } + t.FailNow() +} + +// InDeltaf asserts that the two numerals are within delta of each other. +// +// assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted") +func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.InDeltaf(t, expected, actual, delta, msg, args...) { + return + } + t.FailNow() +} + +// IsType asserts that the specified objects are of the same type. +func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.IsType(t, expectedType, object, msgAndArgs...) { + return + } + t.FailNow() +} + +// IsTypef asserts that the specified objects are of the same type. +func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.IsTypef(t, expectedType, object, msg, args...) { + return + } + t.FailNow() +} + +// Len asserts that the specified object has specific length. +// Len also fails if the object has a type that len() not accept. +// +// assert.Len(t, mySlice, 3) +func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Len(t, object, length, msgAndArgs...) { + return + } + t.FailNow() +} + +// Lenf asserts that the specified object has specific length. +// Lenf also fails if the object has a type that len() not accept. +// +// assert.Lenf(t, mySlice, 3, "error message %s", "formatted") +func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Lenf(t, object, length, msg, args...) { + return + } + t.FailNow() +} + +// Less asserts that the first element is less than the second +// +// assert.Less(t, 1, 2) +// assert.Less(t, float64(1), float64(2)) +// assert.Less(t, "a", "b") +func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Less(t, e1, e2, msgAndArgs...) { + return + } + t.FailNow() +} + +// LessOrEqual asserts that the first element is less than or equal to the second +// +// assert.LessOrEqual(t, 1, 2) +// assert.LessOrEqual(t, 2, 2) +// assert.LessOrEqual(t, "a", "b") +// assert.LessOrEqual(t, "b", "b") +func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.LessOrEqual(t, e1, e2, msgAndArgs...) { + return + } + t.FailNow() +} + +// LessOrEqualf asserts that the first element is less than or equal to the second +// +// assert.LessOrEqualf(t, 1, 2, "error message %s", "formatted") +// assert.LessOrEqualf(t, 2, 2, "error message %s", "formatted") +// assert.LessOrEqualf(t, "a", "b", "error message %s", "formatted") +// assert.LessOrEqualf(t, "b", "b", "error message %s", "formatted") +func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.LessOrEqualf(t, e1, e2, msg, args...) { + return + } + t.FailNow() +} + +// Lessf asserts that the first element is less than the second +// +// assert.Lessf(t, 1, 2, "error message %s", "formatted") +// assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted") +// assert.Lessf(t, "a", "b", "error message %s", "formatted") +func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Lessf(t, e1, e2, msg, args...) { + return + } + t.FailNow() +} + +// Negative asserts that the specified element is negative +// +// assert.Negative(t, -1) +// assert.Negative(t, -1.23) +func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Negative(t, e, msgAndArgs...) { + return + } + t.FailNow() +} + +// Negativef asserts that the specified element is negative +// +// assert.Negativef(t, -1, "error message %s", "formatted") +// assert.Negativef(t, -1.23, "error message %s", "formatted") +func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Negativef(t, e, msg, args...) { + return + } + t.FailNow() +} + +// Nil asserts that the specified object is nil. +// +// assert.Nil(t, err) +func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Nil(t, object, msgAndArgs...) { + return + } + t.FailNow() +} + +// Nilf asserts that the specified object is nil. +// +// assert.Nilf(t, err, "error message %s", "formatted") +func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Nilf(t, object, msg, args...) { + return + } + t.FailNow() +} + +// NoError asserts that a function returned no error (i.e. `nil`). +// +// actualObj, err := SomeFunction() +// if assert.NoError(t, err) { +// assert.Equal(t, expectedObj, actualObj) +// } +func NoError(t TestingT, err error, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NoError(t, err, msgAndArgs...) { + return + } + t.FailNow() +} + +// NoErrorf asserts that a function returned no error (i.e. `nil`). +// +// actualObj, err := SomeFunction() +// if assert.NoErrorf(t, err, "error message %s", "formatted") { +// assert.Equal(t, expectedObj, actualObj) +// } +func NoErrorf(t TestingT, err error, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NoErrorf(t, err, msg, args...) { + return + } + t.FailNow() +} + +// NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the +// specified substring or element. +// +// assert.NotContains(t, "Hello World", "Earth") +// assert.NotContains(t, ["Hello", "World"], "Earth") +// assert.NotContains(t, {"Hello": "World"}, "Earth") +func NotContains(t TestingT, s interface{}, contains interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotContains(t, s, contains, msgAndArgs...) { + return + } + t.FailNow() +} + +// NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the +// specified substring or element. +// +// assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted") +// assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted") +// assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted") +func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotContainsf(t, s, contains, msg, args...) { + return + } + t.FailNow() +} + +// NotEqual asserts that the specified values are NOT equal. +// +// assert.NotEqual(t, obj1, obj2) +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). +func NotEqual(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotEqual(t, expected, actual, msgAndArgs...) { + return + } + t.FailNow() +} + +// NotEqualValues asserts that two objects are not equal even when converted to the same type +// +// assert.NotEqualValues(t, obj1, obj2) +func NotEqualValues(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotEqualValues(t, expected, actual, msgAndArgs...) { + return + } + t.FailNow() +} + +// NotEqualValuesf asserts that two objects are not equal even when converted to the same type +// +// assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted") +func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotEqualValuesf(t, expected, actual, msg, args...) { + return + } + t.FailNow() +} + +// NotEqualf asserts that the specified values are NOT equal. +// +// assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted") +// +// Pointer variable equality is determined based on the equality of the +// referenced values (as opposed to the memory addresses). +func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotEqualf(t, expected, actual, msg, args...) { + return + } + t.FailNow() +} + +// NotNil asserts that the specified object is not nil. +// +// assert.NotNil(t, err) +func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotNil(t, object, msgAndArgs...) { + return + } + t.FailNow() +} + +// NotNilf asserts that the specified object is not nil. +// +// assert.NotNilf(t, err, "error message %s", "formatted") +func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotNilf(t, object, msg, args...) { + return + } + t.FailNow() +} + +// Positive asserts that the specified element is positive +// +// assert.Positive(t, 1) +// assert.Positive(t, 1.23) +func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Positive(t, e, msgAndArgs...) { + return + } + t.FailNow() +} + +// Positivef asserts that the specified element is positive +// +// assert.Positivef(t, 1, "error message %s", "formatted") +// assert.Positivef(t, 1.23, "error message %s", "formatted") +func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Positivef(t, e, msg, args...) { + return + } + t.FailNow() +} + +// True asserts that the specified value is true. +// +// assert.True(t, myBool) +func True(t TestingT, value bool, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.True(t, value, msgAndArgs...) { + return + } + t.FailNow() +} + +// Truef asserts that the specified value is true. +// +// assert.Truef(t, myBool, "error message %s", "formatted") +func Truef(t TestingT, value bool, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.Truef(t, value, msg, args...) { + return + } + t.FailNow() +} + +// WithinDuration asserts that the two times are within duration delta of each other. +// +// assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second) +func WithinDuration(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.WithinDuration(t, expected, actual, delta, msgAndArgs...) { + return + } + t.FailNow() +} + +// WithinDurationf asserts that the two times are within duration delta of each other. +// +// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") +func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.WithinDurationf(t, expected, actual, delta, msg, args...) { + return + } + t.FailNow() +} diff --git a/drivers/mongov2/internal/serverselector/server_selector.go b/drivers/mongov2/internal/serverselector/server_selector.go new file mode 100644 index 0000000..86b3373 --- /dev/null +++ b/drivers/mongov2/internal/serverselector/server_selector.go @@ -0,0 +1,359 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package serverselector + +import ( + "fmt" + "math" + "time" + + "go.mongodb.org/mongo-driver/v2/mongo/readpref" + "go.mongodb.org/mongo-driver/v2/tag" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" +) + +// Composite combines multiple selectors into a single selector by applying them +// in order to the candidates list. +// +// For example, if the initial candidates list is [s0, s1, s2, s3] and two +// selectors are provided where the first matches s0 and s1 and the second +// matches s1 and s2, the following would occur during server selection: +// +// 1. firstSelector([s0, s1, s2, s3]) -> [s0, s1] +// 2. secondSelector([s0, s1]) -> [s1] +// +// The final list of candidates returned by the composite selector would be +// [s1]. +type Composite struct { + Selectors []description.ServerSelector +} + +var _ description.ServerSelector = &Composite{} + +// SelectServer combines multiple selectors into a single selector. +func (selector *Composite) SelectServer( + topo description.Topology, + candidates []description.Server, +) ([]description.Server, error) { + var err error + for _, sel := range selector.Selectors { + candidates, err = sel.SelectServer(topo, candidates) + if err != nil { + return nil, err + } + } + + return candidates, nil +} + +// Latency creates a ServerSelector which selects servers based on their average +// RTT values. +type Latency struct { + Latency time.Duration +} + +var _ description.ServerSelector = &Latency{} + +// SelectServer selects servers based on average RTT. +func (selector *Latency) SelectServer( + topo description.Topology, + candidates []description.Server, +) ([]description.Server, error) { + if selector.Latency < 0 { + return candidates, nil + } + if topo.Kind == description.TopologyKindLoadBalanced { + // In LoadBalanced mode, there should only be one server in the topology and + // it must be selected. + return candidates, nil + } + + switch len(candidates) { + case 0, 1: + return candidates, nil + default: + min := time.Duration(math.MaxInt64) + for _, candidate := range candidates { + if candidate.AverageRTTSet { + if candidate.AverageRTT < min { + min = candidate.AverageRTT + } + } + } + + if min == math.MaxInt64 { + return candidates, nil + } + + max := min + selector.Latency + + viableIndexes := make([]int, 0, len(candidates)) + for i, candidate := range candidates { + if candidate.AverageRTTSet { + if candidate.AverageRTT <= max { + viableIndexes = append(viableIndexes, i) + } + } + } + if len(viableIndexes) == len(candidates) { + return candidates, nil + } + result := make([]description.Server, len(viableIndexes)) + for i, idx := range viableIndexes { + result[i] = candidates[idx] + } + return result, nil + } +} + +// ReadPref selects servers based on the provided read preference. +type ReadPref struct { + ReadPref *readpref.ReadPref + IsOutputAggregate bool +} + +var _ description.ServerSelector = &ReadPref{} + +// SelectServer selects servers based on read preference. +func (selector *ReadPref) SelectServer( + topo description.Topology, + candidates []description.Server, +) ([]description.Server, error) { + if topo.Kind == description.TopologyKindLoadBalanced { + // In LoadBalanced mode, there should only be one server in the topology and + // it must be selected. We check this before checking MaxStaleness support + // because there's no monitoring in this mode, so the candidate server + // wouldn't have a wire version set, which would result in an error. + return candidates, nil + } + + switch topo.Kind { + case description.TopologyKindSingle: + return candidates, nil + case description.TopologyKindReplicaSetNoPrimary, description.TopologyKindReplicaSetWithPrimary: + return selectForReplicaSet(selector.ReadPref, selector.IsOutputAggregate, topo, candidates) + case description.TopologyKindSharded: + return selectByKind(candidates, description.ServerKindMongos), nil + } + + return nil, nil +} + +// Write selects all the writable servers. +type Write struct{} + +var _ description.ServerSelector = &Write{} + +// SelectServer selects all writable servers. +func (selector *Write) SelectServer( + topo description.Topology, + candidates []description.Server, +) ([]description.Server, error) { + switch topo.Kind { + case description.TopologyKindSingle, description.TopologyKindLoadBalanced: + return candidates, nil + default: + // Determine the capacity of the results slice. + selected := 0 + for _, candidate := range candidates { + switch candidate.Kind { + case description.ServerKindMongos, description.ServerKindRSPrimary, description.ServerKindStandalone: + selected++ + } + } + + // Append candidates to the results slice. + result := make([]description.Server, 0, selected) + for _, candidate := range candidates { + switch candidate.Kind { + case description.ServerKindMongos, description.ServerKindRSPrimary, description.ServerKindStandalone: + result = append(result, candidate) + } + } + return result, nil + } +} + +// Func is a function that can be used as a ServerSelector. +type Func func(description.Topology, []description.Server) ([]description.Server, error) + +// SelectServer implements the ServerSelector interface. +func (ssf Func) SelectServer( + t description.Topology, + s []description.Server, +) ([]description.Server, error) { + return ssf(t, s) +} + +func verifyMaxStaleness(rp *readpref.ReadPref, topo description.Topology) error { + maxStaleness, set := rp.MaxStaleness() + if !set { + return nil + } + + if maxStaleness < 90*time.Second { + return fmt.Errorf("max staleness (%s) must be greater than or equal to 90s", maxStaleness) + } + + if len(topo.Servers) < 1 { + // Maybe we should return an error here instead? + return nil + } + + // we'll assume all candidates have the same heartbeat interval. + s := topo.Servers[0] + idleWritePeriod := 10 * time.Second + + if maxStaleness < s.HeartbeatInterval+idleWritePeriod { + return fmt.Errorf( + "max staleness (%s) must be greater than or equal to the heartbeat interval (%s) plus idle write period (%s)", + maxStaleness, s.HeartbeatInterval, idleWritePeriod, + ) + } + + return nil +} + +func selectByKind(candidates []description.Server, kind description.ServerKind) []description.Server { + // Record the indices of viable candidates first and then append those to the returned slice + // to avoid appending costly Server structs directly as an optimization. + viableIndexes := make([]int, 0, len(candidates)) + for i, s := range candidates { + if s.Kind == kind { + viableIndexes = append(viableIndexes, i) + } + } + if len(viableIndexes) == len(candidates) { + return candidates + } + result := make([]description.Server, len(viableIndexes)) + for i, idx := range viableIndexes { + result[i] = candidates[idx] + } + return result +} + +func selectSecondaries(rp *readpref.ReadPref, candidates []description.Server) []description.Server { + secondaries := selectByKind(candidates, description.ServerKindRSSecondary) + if len(secondaries) == 0 { + return secondaries + } + if maxStaleness, set := rp.MaxStaleness(); set { + primaries := selectByKind(candidates, description.ServerKindRSPrimary) + if len(primaries) == 0 { + baseTime := secondaries[0].LastWriteTime + for i := 1; i < len(secondaries); i++ { + if secondaries[i].LastWriteTime.After(baseTime) { + baseTime = secondaries[i].LastWriteTime + } + } + + var selected []description.Server + for _, secondary := range secondaries { + estimatedStaleness := baseTime.Sub(secondary.LastWriteTime) + secondary.HeartbeatInterval + if estimatedStaleness <= maxStaleness { + selected = append(selected, secondary) + } + } + + return selected + } + + primary := primaries[0] + + var selected []description.Server + for _, secondary := range secondaries { + estimatedStaleness := secondary.LastUpdateTime.Sub(secondary.LastWriteTime) - + primary.LastUpdateTime.Sub(primary.LastWriteTime) + secondary.HeartbeatInterval + if estimatedStaleness <= maxStaleness { + selected = append(selected, secondary) + } + } + return selected + } + + return secondaries +} + +func selectByTagSet(candidates []description.Server, tagSets []tag.Set) []description.Server { + if len(tagSets) == 0 { + return candidates + } + + for _, ts := range tagSets { + // If this tag set is empty, we can take a fast path because the empty list + // is a subset of all tag sets, so all candidate servers will be selected. + if len(ts) == 0 { + return candidates + } + + var results []description.Server + for _, s := range candidates { + // ts is non-empty, so only servers with a non-empty set of tags need to be checked. + if len(s.Tags) > 0 && s.Tags.ContainsAll(ts) { + results = append(results, s) + } + } + + if len(results) > 0 { + return results + } + } + + return []description.Server{} +} + +func selectForReplicaSet( + rp *readpref.ReadPref, + isOutputAggregate bool, + topo description.Topology, + candidates []description.Server, +) ([]description.Server, error) { + if err := verifyMaxStaleness(rp, topo); err != nil { + return nil, err + } + + // If underlying operation is an aggregate with an output stage, only apply read preference + // if all candidates are 5.0+. Otherwise, operate under primary read preference. + if isOutputAggregate { + for _, s := range candidates { + if s.WireVersion.Max < 13 { + return selectByKind(candidates, description.ServerKindRSPrimary), nil + } + } + } + + switch rp.Mode() { + case readpref.PrimaryMode: + return selectByKind(candidates, description.ServerKindRSPrimary), nil + case readpref.PrimaryPreferredMode: + selected := selectByKind(candidates, description.ServerKindRSPrimary) + + if len(selected) == 0 { + selected = selectSecondaries(rp, candidates) + return selectByTagSet(selected, rp.TagSets()), nil + } + + return selected, nil + case readpref.SecondaryPreferredMode: + selected := selectSecondaries(rp, candidates) + selected = selectByTagSet(selected, rp.TagSets()) + if len(selected) > 0 { + return selected, nil + } + return selectByKind(candidates, description.ServerKindRSPrimary), nil + case readpref.SecondaryMode: + selected := selectSecondaries(rp, candidates) + return selectByTagSet(selected, rp.TagSets()), nil + case readpref.NearestMode: + selected := selectByKind(candidates, description.ServerKindRSPrimary) + selected = append(selected, selectSecondaries(rp, candidates)...) + return selectByTagSet(selected, rp.TagSets()), nil + } + + return nil, fmt.Errorf("unsupported mode: %d", rp.Mode()) +} diff --git a/drivers/mongov2/internal/serverselector/server_selector_test.go b/drivers/mongov2/internal/serverselector/server_selector_test.go new file mode 100644 index 0000000..ec26853 --- /dev/null +++ b/drivers/mongov2/internal/serverselector/server_selector_test.go @@ -0,0 +1,1278 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package serverselector + +import ( + "errors" + "io/ioutil" + "path" + "testing" + "time" + + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/assert" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/driverutil" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/require" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/spectest" + "github.com/google/go-cmp/cmp" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo/address" + "go.mongodb.org/mongo-driver/v2/mongo/readpref" + "go.mongodb.org/mongo-driver/v2/tag" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" +) + +type lastWriteDate struct { + LastWriteDate int64 `bson:"lastWriteDate"` +} + +type serverDesc struct { + Address string `bson:"address"` + AverageRTTMS *int `bson:"avg_rtt_ms"` + MaxWireVersion *int32 `bson:"maxWireVersion"` + LastUpdateTime *int `bson:"lastUpdateTime"` + LastWrite *lastWriteDate `bson:"lastWrite"` + Type string `bson:"type"` + Tags map[string]string `bson:"tags"` +} + +type topDesc struct { + Type string `bson:"type"` + Servers []*serverDesc `bson:"servers"` +} + +type readPref struct { + MaxStaleness *int `bson:"maxStalenessSeconds"` + Mode string `bson:"mode"` + TagSets []map[string]string `bson:"tag_sets"` +} + +type testCase struct { + TopologyDescription topDesc `bson:"topology_description"` + Operation string `bson:"operation"` + ReadPreference readPref `bson:"read_preference"` + SuitableServers []*serverDesc `bson:"suitable_servers"` + InLatencyWindow []*serverDesc `bson:"in_latency_window"` + HeartbeatFrequencyMS *int `bson:"heartbeatFrequencyMS"` + Error *bool +} + +func serverKindFromString(t *testing.T, s string) description.ServerKind { + t.Helper() + + switch s { + case "Standalone": + return description.ServerKindStandalone + case "RSOther": + return description.ServerKindRSMember + case "RSPrimary": + return description.ServerKindRSPrimary + case "RSSecondary": + return description.ServerKindRSSecondary + case "RSArbiter": + return description.ServerKindRSArbiter + case "RSGhost": + return description.ServerKindRSGhost + case "Mongos": + return description.ServerKindMongos + case "LoadBalancer": + return description.ServerKindLoadBalancer + case "PossiblePrimary", "Unknown": + // Go does not have a PossiblePrimary server type and per the SDAM spec, this type is synonymous with Unknown. + return description.Unknown + default: + t.Fatalf("unrecognized server kind: %q", s) + } + + return description.Unknown +} + +func topologyKindFromString(t *testing.T, s string) description.TopologyKind { + t.Helper() + + switch s { + case "Single": + return description.TopologyKindSingle + case "ReplicaSet": + return description.TopologyKindReplicaSet + case "ReplicaSetNoPrimary": + return description.TopologyKindReplicaSetNoPrimary + case "ReplicaSetWithPrimary": + return description.TopologyKindReplicaSetWithPrimary + case "Sharded": + return description.TopologyKindSharded + case "LoadBalanced": + return description.TopologyKindLoadBalanced + case "Unknown": + return description.Unknown + default: + t.Fatalf("unrecognized topology kind: %q", s) + } + + return description.Unknown +} + +func anyTagsInSets(sets []tag.Set) bool { + for _, set := range sets { + if len(set) > 0 { + return true + } + } + + return false +} + +func findServerByAddress(servers []description.Server, address string) description.Server { + for _, server := range servers { + if server.Addr.String() == address { + return server + } + } + + return description.Server{} +} + +func compareServers(t *testing.T, expected []*serverDesc, actual []description.Server) { + require.Equal(t, len(expected), len(actual)) + + for _, expectedServer := range expected { + actualServer := findServerByAddress(actual, expectedServer.Address) + require.NotNil(t, actualServer) + + if expectedServer.AverageRTTMS != nil { + require.Equal(t, *expectedServer.AverageRTTMS, int(actualServer.AverageRTT/time.Millisecond)) + } + + require.Equal(t, expectedServer.Type, actualServer.Kind.String()) + + require.Equal(t, len(expectedServer.Tags), len(actualServer.Tags)) + for _, actualTag := range actualServer.Tags { + expectedTag, ok := expectedServer.Tags[actualTag.Name] + require.True(t, ok) + require.Equal(t, expectedTag, actualTag.Value) + } + } +} + +const maxStalenessTestsDir = "../../testdata/max-staleness" + +// Test case for all max staleness spec tests. +func TestMaxStalenessSpec(t *testing.T) { + for _, topology := range [...]string{ + "ReplicaSetNoPrimary", + "ReplicaSetWithPrimary", + "Sharded", + "Single", + "Unknown", + } { + for _, file := range spectest.FindJSONFilesInDir(t, + path.Join(maxStalenessTestsDir, topology)) { + + runTest(t, maxStalenessTestsDir, topology, file) + } + } +} + +const selectorTestsDir = "../../testdata/server-selection/server_selection" + +func selectServers(t *testing.T, test *testCase) error { + servers := make([]description.Server, 0, len(test.TopologyDescription.Servers)) + + // Times in the JSON files are given as offsets from an unspecified time, but the driver + // stores the lastWrite field as a timestamp, so we arbitrarily choose the current time + // as the base to offset from. + baseTime := time.Now() + + for _, serverDescription := range test.TopologyDescription.Servers { + server := description.Server{ + Addr: address.Address(serverDescription.Address), + Kind: serverKindFromString(t, serverDescription.Type), + } + + if serverDescription.AverageRTTMS != nil { + server.AverageRTT = time.Duration(*serverDescription.AverageRTTMS) * time.Millisecond + server.AverageRTTSet = true + } + + if test.HeartbeatFrequencyMS != nil { + server.HeartbeatInterval = time.Duration(*test.HeartbeatFrequencyMS) * time.Millisecond + } + + if serverDescription.LastUpdateTime != nil { + ms := int64(*serverDescription.LastUpdateTime) + server.LastUpdateTime = time.Unix(ms/1e3, ms%1e3/1e6) + } + + if serverDescription.LastWrite != nil { + i := serverDescription.LastWrite.LastWriteDate + + timeWithOffset := baseTime.Add(time.Duration(i) * time.Millisecond) + server.LastWriteTime = timeWithOffset + } + + if serverDescription.MaxWireVersion != nil { + versionRange := driverutil.NewVersionRange(0, *serverDescription.MaxWireVersion) + server.WireVersion = &versionRange + } + + if serverDescription.Tags != nil { + server.Tags = tag.NewTagSetFromMap(serverDescription.Tags) + } + + if test.ReadPreference.MaxStaleness != nil && server.WireVersion == nil { + server.WireVersion = &description.VersionRange{Max: 21} + } + + servers = append(servers, server) + } + + c := description.Topology{ + Kind: topologyKindFromString(t, test.TopologyDescription.Type), + Servers: servers, + } + + if len(test.ReadPreference.Mode) == 0 { + test.ReadPreference.Mode = "Primary" + } + + readprefMode, err := readpref.ModeFromString(test.ReadPreference.Mode) + if err != nil { + return err + } + + options := make([]readpref.Option, 0, 1) + + tagSets := tag.NewTagSetsFromMaps(test.ReadPreference.TagSets) + if anyTagsInSets(tagSets) { + options = append(options, readpref.WithTagSets(tagSets...)) + } + + if test.ReadPreference.MaxStaleness != nil { + s := time.Duration(*test.ReadPreference.MaxStaleness) * time.Second + options = append(options, readpref.WithMaxStaleness(s)) + } + + rp, err := readpref.New(readprefMode, options...) + if err != nil { + return err + } + + var selector description.ServerSelector + + selector = &ReadPref{ReadPref: rp} + if test.Operation == "write" { + selector = &Composite{ + Selectors: []description.ServerSelector{&Write{}, selector}, + } + } + + result, err := selector.SelectServer(c, c.Servers) + if err != nil { + return err + } + + compareServers(t, test.SuitableServers, result) + + latencySelector := &Latency{Latency: time.Duration(15) * time.Millisecond} + selector = &Composite{ + Selectors: []description.ServerSelector{selector, latencySelector}, + } + + result, err = selector.SelectServer(c, c.Servers) + if err != nil { + return err + } + + compareServers(t, test.InLatencyWindow, result) + + return nil +} + +func runTest(t *testing.T, testsDir string, directory string, filename string) { + filepath := path.Join(testsDir, directory, filename) + content, err := ioutil.ReadFile(filepath) + require.NoError(t, err) + + // Remove ".json" from filename. + filename = filename[:len(filename)-5] + testName := directory + "/" + filename + ":" + + t.Run(testName, func(t *testing.T) { + var test testCase + require.NoError(t, bson.UnmarshalExtJSON(content, true, &test)) + + err := selectServers(t, &test) + + if test.Error == nil || !*test.Error { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) +} + +// Test case for all SDAM spec tests. +func TestServerSelectionSpec(t *testing.T) { + for _, topology := range [...]string{ + "ReplicaSetNoPrimary", + "ReplicaSetWithPrimary", + "Sharded", + "Single", + "Unknown", + "LoadBalanced", + } { + for _, subdir := range [...]string{"read", "write"} { + subdirPath := path.Join(topology, subdir) + + for _, file := range spectest.FindJSONFilesInDir(t, + path.Join(selectorTestsDir, subdirPath)) { + + runTest(t, selectorTestsDir, subdirPath, file) + } + } + } +} + +func TestServerSelection(t *testing.T) { + noerr := func(t *testing.T, err error) { + if err != nil { + t.Errorf("Unepexted error: %v", err) + t.FailNow() + } + } + + t.Run("WriteSelector", func(t *testing.T) { + testCases := []struct { + name string + desc description.Topology + start int + end int + }{ + { + name: "ReplicaSetWithPrimary", + desc: description.Topology{ + Kind: description.TopologyKindReplicaSetWithPrimary, + Servers: []description.Server{ + {Addr: address.Address("localhost:27017"), Kind: description.ServerKindRSPrimary}, + {Addr: address.Address("localhost:27018"), Kind: description.ServerKindRSSecondary}, + {Addr: address.Address("localhost:27019"), Kind: description.ServerKindRSSecondary}, + }, + }, + start: 0, + end: 1, + }, + { + name: "ReplicaSetNoPrimary", + desc: description.Topology{ + Kind: description.TopologyKindReplicaSetNoPrimary, + Servers: []description.Server{ + {Addr: address.Address("localhost:27018"), Kind: description.ServerKindRSSecondary}, + {Addr: address.Address("localhost:27019"), Kind: description.ServerKindRSSecondary}, + }, + }, + start: 0, + end: 0, + }, + { + name: "Sharded", + desc: description.Topology{ + Kind: description.TopologyKindSharded, + Servers: []description.Server{ + {Addr: address.Address("localhost:27018"), Kind: description.ServerKindMongos}, + {Addr: address.Address("localhost:27019"), Kind: description.ServerKindMongos}, + }, + }, + start: 0, + end: 2, + }, + { + name: "Single", + desc: description.Topology{ + Kind: description.TopologyKindSingle, + Servers: []description.Server{ + {Addr: address.Address("localhost:27018"), Kind: description.ServerKindStandalone}, + }, + }, + start: 0, + end: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := (&Write{}).SelectServer(tc.desc, tc.desc.Servers) + noerr(t, err) + if len(result) != tc.end-tc.start { + t.Errorf("Incorrect number of servers selected. got %d; want %d", len(result), tc.end-tc.start) + } + if diff := cmp.Diff(result, tc.desc.Servers[tc.start:tc.end]); diff != "" { + t.Errorf("Incorrect servers selected (-got +want):\n%s", diff) + } + }) + } + }) + t.Run("LatencySelector", func(t *testing.T) { + testCases := []struct { + name string + desc description.Topology + start int + end int + }{ + { + name: "NoRTTSet", + desc: description.Topology{ + Servers: []description.Server{ + {Addr: address.Address("localhost:27017")}, + {Addr: address.Address("localhost:27018")}, + {Addr: address.Address("localhost:27019")}, + }, + }, + start: 0, + end: 3, + }, + { + name: "MultipleServers PartialNoRTTSet", + desc: description.Topology{ + Servers: []description.Server{ + {Addr: address.Address("localhost:27017"), AverageRTT: 5 * time.Second, AverageRTTSet: true}, + {Addr: address.Address("localhost:27018"), AverageRTT: 10 * time.Second, AverageRTTSet: true}, + {Addr: address.Address("localhost:27019")}, + }, + }, + start: 0, + end: 2, + }, + { + name: "MultipleServers", + desc: description.Topology{ + Servers: []description.Server{ + {Addr: address.Address("localhost:27017"), AverageRTT: 5 * time.Second, AverageRTTSet: true}, + {Addr: address.Address("localhost:27018"), AverageRTT: 10 * time.Second, AverageRTTSet: true}, + {Addr: address.Address("localhost:27019"), AverageRTT: 26 * time.Second, AverageRTTSet: true}, + }, + }, + start: 0, + end: 2, + }, + { + name: "No Servers", + desc: description.Topology{Servers: []description.Server{}}, + start: 0, + end: 0, + }, + { + name: "1 Server", + desc: description.Topology{ + Servers: []description.Server{ + {Addr: address.Address("localhost:27017"), AverageRTT: 26 * time.Second, AverageRTTSet: true}, + }, + }, + start: 0, + end: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := (&Latency{Latency: 20 * time.Second}).SelectServer(tc.desc, tc.desc.Servers) + noerr(t, err) + if len(result) != tc.end-tc.start { + t.Errorf("Incorrect number of servers selected. got %d; want %d", len(result), tc.end-tc.start) + } + if diff := cmp.Diff(result, tc.desc.Servers[tc.start:tc.end]); diff != "" { + t.Errorf("Incorrect servers selected (-got +want):\n%s", diff) + } + }) + } + }) +} + +var readPrefTestPrimary = description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindRSPrimary, + Tags: tag.Set{tag.Tag{Name: "a", Value: "1"}}, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, +} +var readPrefTestSecondary1 = description.Server{ + Addr: address.Address("localhost:27018"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 13, 58, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindRSSecondary, + Tags: tag.Set{tag.Tag{Name: "a", Value: "1"}}, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, +} +var readPrefTestSecondary2 = description.Server{ + Addr: address.Address("localhost:27018"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindRSSecondary, + Tags: tag.Set{tag.Tag{Name: "a", Value: "2"}}, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, +} +var readPrefTestTopology = description.Topology{ + Kind: description.TopologyKindReplicaSetWithPrimary, + Servers: []description.Server{readPrefTestPrimary, readPrefTestSecondary1, readPrefTestSecondary2}, +} + +func TestSelector_Sharded(t *testing.T) { + t.Parallel() + + subject := readpref.Primary() + + s := description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindMongos, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + c := description.Topology{ + Kind: description.TopologyKindSharded, + Servers: []description.Server{s}, + } + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(c, c.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{s}, result) +} + +func BenchmarkLatencySelector(b *testing.B) { + for _, bcase := range []struct { + name string + serversHook func(servers []description.Server) + }{ + { + name: "AllFit", + serversHook: func([]description.Server) {}, + }, + { + name: "AllButOneFit", + serversHook: func(servers []description.Server) { + servers[0].AverageRTT = 2 * time.Second + }, + }, + { + name: "HalfFit", + serversHook: func(servers []description.Server) { + for i := 0; i < len(servers); i += 2 { + servers[i].AverageRTT = 2 * time.Second + } + }, + }, + { + name: "OneFit", + serversHook: func(servers []description.Server) { + for i := 1; i < len(servers); i++ { + servers[i].AverageRTT = 2 * time.Second + } + }, + }, + } { + bcase := bcase + + b.Run(bcase.name, func(b *testing.B) { + s := description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindMongos, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + AverageRTTSet: true, + AverageRTT: time.Second, + } + servers := make([]description.Server, 100) + for i := 0; i < len(servers); i++ { + servers[i] = s + } + bcase.serversHook(servers) + // this will make base 1 sec latency < min (0.5) + conf (1) + // and high latency 2 higher than the threshold + servers[99].AverageRTT = 500 * time.Millisecond + c := description.Topology{ + Kind: description.TopologyKindSharded, + Servers: servers, + } + + b.ResetTimer() + b.RunParallel(func(p *testing.PB) { + b.ReportAllocs() + for p.Next() { + _, _ = (&Latency{Latency: time.Second}).SelectServer(c, c.Servers) + } + }) + }) + } +} + +func BenchmarkSelector_Sharded(b *testing.B) { + for _, bcase := range []struct { + name string + serversHook func(servers []description.Server) + }{ + { + name: "AllFit", + serversHook: func([]description.Server) {}, + }, + { + name: "AllButOneFit", + serversHook: func(servers []description.Server) { + servers[0].Kind = description.ServerKindLoadBalancer + }, + }, + { + name: "HalfFit", + serversHook: func(servers []description.Server) { + for i := 0; i < len(servers); i += 2 { + servers[i].Kind = description.ServerKindLoadBalancer + } + }, + }, + { + name: "OneFit", + serversHook: func(servers []description.Server) { + for i := 1; i < len(servers); i++ { + servers[i].Kind = description.ServerKindLoadBalancer + } + }, + }, + } { + bcase := bcase + + b.Run(bcase.name, func(b *testing.B) { + subject := readpref.Primary() + + s := description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindMongos, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + servers := make([]description.Server, 100) + for i := 0; i < len(servers); i++ { + servers[i] = s + } + bcase.serversHook(servers) + c := description.Topology{ + Kind: description.TopologyKindSharded, + Servers: servers, + } + + b.ResetTimer() + b.RunParallel(func(p *testing.PB) { + b.ReportAllocs() + for p.Next() { + _, _ = (&ReadPref{ReadPref: subject}).SelectServer(c, c.Servers) + } + }) + }) + } +} + +func Benchmark_SelectServer_SelectServer(b *testing.B) { + topology := description.Topology{Kind: description.TopologyKindReplicaSet} // You can change the topology as needed + candidates := []description.Server{ + {Kind: description.ServerKindMongos}, + {Kind: description.ServerKindRSPrimary}, + {Kind: description.ServerKindStandalone}, + } + + selector := &Write{} // Assuming this is the receiver type + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := selector.SelectServer(topology, candidates) + if err != nil { + b.Fatalf("Error selecting server: %v", err) + } + } +} + +func TestSelector_Single(t *testing.T) { + t.Parallel() + + subject := readpref.Primary() + + s := description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindMongos, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + c := description.Topology{ + Kind: description.TopologyKindSingle, + Servers: []description.Server{s}, + } + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(c, c.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{s}, result) +} + +func TestSelector_Primary(t *testing.T) { + t.Parallel() + + subject := readpref.Primary() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_Primary_with_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.Primary() + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 0) +} + +func TestSelector_PrimaryPreferred(t *testing.T) { + t.Parallel() + + subject := readpref.PrimaryPreferred() + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_PrimaryPreferred_ignores_tags(t *testing.T) { + t.Parallel() + + subject := readpref.PrimaryPreferred( + readpref.WithTags("a", "2"), + ) + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_PrimaryPreferred_with_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.PrimaryPreferred() + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}, result) +} + +func TestSelector_PrimaryPreferred_with_no_primary_and_tags(t *testing.T) { + t.Parallel() + + subject := readpref.PrimaryPreferred( + readpref.WithTags("a", "2"), + ) + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_PrimaryPreferred_with_maxStaleness(t *testing.T) { + t.Parallel() + + subject := readpref.PrimaryPreferred( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_PrimaryPreferred_with_maxStaleness_and_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.PrimaryPreferred( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_SecondaryPreferred(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}, result) +} + +func TestSelector_SecondaryPreferred_with_tags(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred( + readpref.WithTags("a", "2"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_SecondaryPreferred_with_tags_that_do_not_match(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred( + readpref.WithTags("a", "3"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_SecondaryPreferred_with_tags_that_do_not_match_and_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred( + readpref.WithTags("a", "3"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 0) +} + +func TestSelector_SecondaryPreferred_with_no_secondaries(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestPrimary}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_SecondaryPreferred_with_no_secondaries_or_primary(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{}) + + require.NoError(t, err) + require.Len(t, result, 0) +} + +func TestSelector_SecondaryPreferred_with_maxStaleness(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_SecondaryPreferred_with_maxStaleness_and_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_Secondary(t *testing.T) { + t.Parallel() + + subject := readpref.Secondary() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}, result) +} + +func TestSelector_Secondary_with_tags(t *testing.T) { + t.Parallel() + + subject := readpref.Secondary( + readpref.WithTags("a", "2"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_Secondary_with_empty_tag_set(t *testing.T) { + t.Parallel() + + primaryNoTags := description.Server{ + Addr: address.Address("localhost:27017"), + Kind: description.ServerKindRSPrimary, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + firstSecondaryNoTags := description.Server{ + Addr: address.Address("localhost:27018"), + Kind: description.ServerKindRSSecondary, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + secondSecondaryNoTags := description.Server{ + Addr: address.Address("localhost:27019"), + Kind: description.ServerKindRSSecondary, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + topologyNoTags := description.Topology{ + Kind: description.TopologyKindReplicaSetWithPrimary, + Servers: []description.Server{primaryNoTags, firstSecondaryNoTags, secondSecondaryNoTags}, + } + + nonMatchingSet := tag.Set{ + {Name: "foo", Value: "bar"}, + } + emptyTagSet := tag.Set{} + rp := readpref.Secondary( + readpref.WithTagSets(nonMatchingSet, emptyTagSet), + ) + + result, err := (&ReadPref{ReadPref: rp}).SelectServer(topologyNoTags, topologyNoTags.Servers) + assert.Nil(t, err, "SelectServer error: %v", err) + expectedResult := []description.Server{firstSecondaryNoTags, secondSecondaryNoTags} + assert.Equal(t, expectedResult, result, "expected result %v, got %v", expectedResult, result) +} + +func TestSelector_Secondary_with_tags_that_do_not_match(t *testing.T) { + t.Parallel() + + subject := readpref.Secondary( + readpref.WithTags("a", "3"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 0) +} + +func TestSelector_Secondary_with_no_secondaries(t *testing.T) { + t.Parallel() + + subject := readpref.Secondary() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestPrimary}) + + require.NoError(t, err) + require.Len(t, result, 0) +} + +func TestSelector_Secondary_with_maxStaleness(t *testing.T) { + t.Parallel() + + subject := readpref.Secondary( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_Secondary_with_maxStaleness_and_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.Secondary( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_Nearest(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 3) + require.Equal(t, []description.Server{readPrefTestPrimary, readPrefTestSecondary1, readPrefTestSecondary2}, result) +} + +func TestSelector_Nearest_with_tags(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest( + readpref.WithTags("a", "1"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, []description.Server{readPrefTestPrimary, readPrefTestSecondary1}, result) +} + +func TestSelector_Nearest_with_tags_that_do_not_match(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest( + readpref.WithTags("a", "3"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 0) +} + +func TestSelector_Nearest_with_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}, result) +} + +func TestSelector_Nearest_with_no_secondaries(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestPrimary}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_Nearest_with_maxStaleness(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, []description.Server{readPrefTestPrimary, readPrefTestSecondary2}, result) +} + +func TestSelector_Nearest_with_maxStaleness_and_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_Max_staleness_is_less_than_90_seconds(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest( + readpref.WithMaxStaleness(time.Duration(50) * time.Second), + ) + + s := description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindRSPrimary, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + c := description.Topology{ + Kind: description.TopologyKindReplicaSetWithPrimary, + Servers: []description.Server{s}, + } + + _, err := (&ReadPref{ReadPref: subject}).SelectServer(c, c.Servers) + + require.Error(t, err) +} + +func TestSelector_Max_staleness_is_too_low(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest( + readpref.WithMaxStaleness(time.Duration(100) * time.Second), + ) + + s := description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(100) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindRSPrimary, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + c := description.Topology{ + Kind: description.TopologyKindReplicaSetWithPrimary, + Servers: []description.Server{s}, + } + + _, err := (&ReadPref{ReadPref: subject}).SelectServer(c, c.Servers) + + require.Error(t, err) +} + +func TestEqualServers(t *testing.T) { + int64ToPtr := func(i64 int64) *int64 { return &i64 } + + t.Run("equals", func(t *testing.T) { + defaultServer := description.Server{} + // Only some of the Server fields affect equality + testCases := []struct { + name string + server description.Server + equal bool + }{ + {"empty", description.Server{}, true}, + {"address", description.Server{Addr: address.Address("foo")}, true}, + {"arbiters", description.Server{Arbiters: []string{"foo"}}, false}, + {"rtt", description.Server{AverageRTT: time.Second}, true}, + {"compression", description.Server{Compression: []string{"foo"}}, true}, + {"canonicalAddr", description.Server{CanonicalAddr: address.Address("foo")}, false}, + {"electionID", description.Server{ElectionID: bson.NewObjectID()}, false}, + {"heartbeatInterval", description.Server{HeartbeatInterval: time.Second}, true}, + {"hosts", description.Server{Hosts: []string{"foo"}}, false}, + {"lastError", description.Server{LastError: errors.New("foo")}, false}, + {"lastUpdateTime", description.Server{LastUpdateTime: time.Now()}, true}, + {"lastWriteTime", description.Server{LastWriteTime: time.Now()}, true}, + {"maxBatchCount", description.Server{MaxBatchCount: 1}, true}, + {"maxDocumentSize", description.Server{MaxDocumentSize: 1}, true}, + {"maxMessageSize", description.Server{MaxMessageSize: 1}, true}, + {"members", description.Server{Members: []address.Address{address.Address("foo")}}, true}, + {"passives", description.Server{Passives: []string{"foo"}}, false}, + {"passive", description.Server{Passive: true}, true}, + {"primary", description.Server{Primary: address.Address("foo")}, false}, + {"readOnly", description.Server{ReadOnly: true}, true}, + { + "sessionTimeoutMinutes", + description.Server{ + SessionTimeoutMinutes: int64ToPtr(1), + }, + false, + }, + {"setName", description.Server{SetName: "foo"}, false}, + {"setVersion", description.Server{SetVersion: 1}, false}, + {"tags", description.Server{Tags: tag.Set{tag.Tag{"foo", "bar"}}}, false}, + {"topologyVersion", description.Server{TopologyVersion: &description.TopologyVersion{bson.NewObjectID(), 0}}, false}, + {"kind", description.Server{Kind: description.ServerKindStandalone}, false}, + {"wireVersion", description.Server{WireVersion: &description.VersionRange{1, 2}}, false}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := driverutil.EqualServers(defaultServer, tc.server) + assert.Equal(t, actual, tc.equal, "expected %v, got %v", tc.equal, actual) + }) + } + }) +} + +func TestVersionRangeIncludes(t *testing.T) { + t.Parallel() + + subject := driverutil.NewVersionRange(1, 3) + + tests := []struct { + n int32 + expected bool + }{ + {0, false}, + {1, true}, + {2, true}, + {3, true}, + {4, false}, + {10, false}, + } + + for _, test := range tests { + actual := driverutil.VersionRangeIncludes(subject, test.n) + if actual != test.expected { + t.Fatalf("expected %v to be %t", test.n, test.expected) + } + } +} diff --git a/drivers/mongov2/internal/spectest/spectest.go b/drivers/mongov2/internal/spectest/spectest.go new file mode 100644 index 0000000..6a9ec68 --- /dev/null +++ b/drivers/mongov2/internal/spectest/spectest.go @@ -0,0 +1,35 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package spectest + +import ( + "io/ioutil" + "path" + "testing" + + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/require" +) + +// FindJSONFilesInDir finds the JSON files in a directory. +func FindJSONFilesInDir(t *testing.T, dir string) []string { + t.Helper() + + files := make([]string, 0) + + entries, err := ioutil.ReadDir(dir) + require.NoError(t, err) + + for _, entry := range entries { + if entry.IsDir() || path.Ext(entry.Name()) != ".json" { + continue + } + + files = append(files, entry.Name()) + } + + return files +} diff --git a/drivers/mongov2/settings_test.go b/drivers/mongov2/settings_test.go index e08bec9..425a19b 100644 --- a/drivers/mongov2/settings_test.go +++ b/drivers/mongov2/settings_test.go @@ -90,13 +90,12 @@ func TestSettings_EnrichBy(t *testing.T) { got := tt.settings.EnrichBy(tt.args.external) - //assert.Equal(t, tt.want, got) - - t.Helper() assert.Equal(t, tt.want.CtxKey(), got.CtxKey()) assert.Equal(t, tt.want.Propagation(), got.Propagation()) assert.Equal(t, tt.want.Cancelable(), got.Cancelable()) assert.Equal(t, tt.want.TimeoutOrNil(), got.TimeoutOrNil()) + + assert.Equal(t, len(tt.want.(Settings).SessionOpts().List()), len(got.(Settings).SessionOpts().List())) }) } } diff --git a/drivers/mongov2/transaction_test.go b/drivers/mongov2/transaction_test.go new file mode 100644 index 0000000..b4bd8ce --- /dev/null +++ b/drivers/mongov2/transaction_test.go @@ -0,0 +1,231 @@ +//go:build go1.21 + +package mongov2 + +import ( + "context" + "errors" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/mtest" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + trmcontext "github.com/avito-tech/go-transaction-manager/trm/v2/context" + "github.com/avito-tech/go-transaction-manager/trm/v2/drivers/mock" + + "github.com/avito-tech/go-transaction-manager/trm/v2" + "github.com/avito-tech/go-transaction-manager/trm/v2/manager" + "github.com/avito-tech/go-transaction-manager/trm/v2/settings" +) + +type user struct { + ID bson.ObjectID `bson:"_id,omitempty"` +} + +func TestTransaction(t *testing.T) { + t.Parallel() + + type args struct { + ctx context.Context + } + + type fields struct { + settings trm.Settings + } + + testErr := errors.New("error test") + doNil := func(_ *mtest.T, _ context.Context) error { + return nil + } + defaultFields := func(_ *mtest.T) fields { + return fields{ + settings: MustSettings(settings.Must( + settings.WithPropagation(trm.PropagationRequiresNew), + ), WithSessionOpts(&options.SessionOptionsBuilder{})), + } + } + + mt := mtest.New( + t, + mtest.NewOptions().ClientType(mtest.Mock), + ) + + tests := map[string]struct { + fields func(mt *mtest.T) fields + args args + do func(mt *mtest.T, ctx context.Context) error + wantErr assert.ErrorAssertionFunc + }{ + "success": { + fields: defaultFields, + args: args{ + ctx: context.Background(), + }, + do: doNil, + wantErr: assert.NoError, + }, + "begin_session_error": { + fields: func(_ *mtest.T) fields { + return fields{ + settings: MustSettings(settings.Must( + settings.WithPropagation(trm.PropagationNested), + ), WithSessionOpts((&options.SessionOptionsBuilder{}). + SetSnapshot(true). + SetCausalConsistency(true))), + } + }, + args: args{ + ctx: context.Background(), + }, + do: func(mt *mtest.T, _ context.Context) error { + require.NotNil(mt, 1, "should not be here") + + return nil + }, + wantErr: func(t assert.TestingT, err error, _ ...interface{}) bool { + return assert.ErrorIs(t, err, trm.ErrBegin) + }, + }, + "begin_transaction_error": { + fields: func(_ *mtest.T) fields { + return fields{ + settings: MustSettings(settings.Must( + settings.WithPropagation(trm.PropagationNested), + ), WithTransactionOpts((&options.TransactionOptionsBuilder{}). + SetWriteConcern(&writeconcern.WriteConcern{W: 0}))), + } + }, + args: args{ + ctx: context.Background(), + }, + do: func(mt *mtest.T, _ context.Context) error { + require.NotNil(mt, 1, "should not be here") + + return nil + }, + wantErr: func(t assert.TestingT, err error, _ ...interface{}) bool { + return assert.ErrorIs(t, err, trm.ErrBegin) + }, + }, + "commit_error": { + fields: defaultFields, + args: args{ + ctx: context.Background(), + }, + do: func(mt *mtest.T, ctx context.Context) error { + _, _ = mt.Coll.InsertOne(ctx, user{ + ID: bson.NewObjectID(), + }) + + return nil + }, + wantErr: func(t assert.TestingT, err error, _ ...interface{}) bool { + var divErr mongo.CommandError + + return assert.ErrorAs(t, err, &divErr) && + assert.ErrorIs(t, err, trm.ErrCommit) + }, + }, + "rollback_after_error": { + fields: defaultFields, + args: args{ + ctx: context.Background(), + }, + do: func(mt *mtest.T, ctx context.Context) error { + s := mongo.SessionFromContext(ctx) + + require.NoError(mt, s.AbortTransaction(ctx)) + + return testErr + }, + wantErr: func(t assert.TestingT, err error, _ ...interface{}) bool { + return assert.ErrorIs(t, err, testErr) && + assert.ErrorIs(t, err, trm.ErrRollback) + }, + }, + } + for name, tt := range tests { + tt := tt + mt.Run(name, func(mt *mtest.T) { + mt.Parallel() + + log := mock.NewLog() + + f := tt.fields(mt) + + m := manager.Must( + NewDefaultFactory(mt.Client), + manager.WithLog(log), + manager.WithSettings(f.settings), + ) + + var tr trm.Transaction + err := m.Do(tt.args.ctx, func(ctx context.Context) error { + tr = trmcontext.DefaultManager.Default(ctx) + + var trNested trm.Transaction + err := m.Do(ctx, func(ctx context.Context) error { + trNested = trmcontext.DefaultManager.Default(ctx) + + require.NotNil(t, trNested) + + return tt.do(mt, ctx) + }) + + if trNested != nil { + require.False(t, trNested.IsActive()) + } + + return err + }) + + if tr != nil { + require.False(t, tr.IsActive()) + } + + if !tt.wantErr(t, err) { + return + } + }) + } +} + +func TestTransaction_awaitDone_byContext(t *testing.T) { + t.Parallel() + + mt := mtest.New( + t, + mtest.NewOptions(). + ClientType(mtest.Mock). + ShareClient(true), + ) + + wg := sync.WaitGroup{} + wg.Add(1) + + f := NewDefaultFactory(mt.Client) + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + defer wg.Done() + + _, tr, err := f(ctx, settings.Must()) + + cancel() + <-time.After(time.Second) + + <-ctx.Done() + + require.NoError(mt, err) + require.False(mt, tr.IsActive()) + }() + + wg.Wait() +} From 8de456e6d21477fa5d3430917863906ff8caf3f4 Mon Sep 17 00:00:00 2001 From: hutiquan Date: Tue, 1 Jul 2025 14:51:28 +0800 Subject: [PATCH 3/7] feat:update readme add mongov2 description --- README.md | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 8b75e46..a6aca4c 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,8 @@ Easiest way to get the perfect repository. Go 1.18) * [mongo-go-driver](https://github.com/mongodb/mongo-go-driver), [docs](https://pkg.go.dev/github.com/avito-tech/go-transaction-manager/drivers/mongo/v2) ( Go 1.13) +* [mongo-go-driver v2](https://github.com/mongodb/mongo-go-driver), [docs](https://pkg.go.dev/github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2) ( + Go 1.21) * [go-redis/redis](https://github.com/go-redis/redis), [docs](https://pkg.go.dev/github.com/avito-tech/go-transaction-manager/drivers/goredis8/v2) ( Go 1.17) * [pgx_v4](https://github.com/jackc/pgx/tree/v4), [docs](https://pkg.go.dev/github.com/avito-tech/go-transaction-manager/drivers/pgxv4/v2) ( @@ -48,14 +50,16 @@ The critical bugs are firstly solved for the most recent two Golang versions and `go get -u && go mod tidy` helps you. -**Note**: The go-transaction-manager uses some old dependencies to support backwards compatibility for old versions of Go. +**Note**: The go-transaction-manager uses some old dependencies to support backwards compatibility for old versions of +Go. ## Usage **To use multiple transactions from different databases**, you need to set CtxKey in [Settings](trm/settings.go) by [WithCtxKey](trm/settings/option.go) ([docs](https://pkg.go.dev/github.com/avito-tech/go-transaction-manager/trm/v2)). -**For nested transactions with different transaction managers**, you need to use [ChainedMW](trm/manager/chain.go) ([docs](https://pkg.go.dev/github.com/avito-tech/go-transaction-manager/trm/v2/manager)). +**For nested transactions with different transaction managers**, you need to +use [ChainedMW](trm/manager/chain.go) ([docs](https://pkg.go.dev/github.com/avito-tech/go-transaction-manager/trm/v2/manager)). **To skip a transaction rollback due to an error, use [ErrSkip](manager.go#L20) or [Skippable](manager.go#L24)** @@ -67,6 +71,7 @@ by [WithCtxKey](trm/settings/option.go) ([docs](https://pkg.go.dev/github.com/av * [jmoiron/sqlx](drivers/sqlx/example_test.go) * [gorm](drivers/gorm/example_test.go) * [mongo-go-driver](drivers/mongo/example_test.go) +* [mongo-go-driver v2](drivers/mongov2/example_test.go) * [go-redis/redis](drivers/goredis8/example_test.go) * [pgx_v4](drivers/pgxv4/example_test.go) * [pgx_v5](drivers/pgxv5/example_test.go) @@ -181,13 +186,14 @@ func (r *repo) Save(ctx context.Context, u *user) error { * To run all tests use `make test` or `make test.with_real_db` for integration tests. To run database by docker, there is [docker-compose.yaml](trm/drivers/test/docker-compose.yaml). + ```bash docker compose -f trm/drivers/test/docker-compose.yaml up ``` For full GitHub Actions run, you can use [act](https://github.com/nektos/act). -#### Running old go versions +#### Running old go versions To stop Golang upgrading set environment variable `GOTOOLCHAIN=local` . @@ -199,6 +205,7 @@ go1.16 install Use `-mod=readonly` to prevent `go.mod` modification. To run tests + ``` go1.16 test -race -mod=readonly ./... ``` @@ -206,19 +213,21 @@ go1.16 test -race -mod=readonly ./... ### How to bump up Golang version in CI/CD 1. Changes in [.github/workflows/main.yaml](.github/workflows/main.yaml). - 1. Add all old version of Go in `go-version:` for `tests-units` job. - 2. Update `go-version:` on current version of Go for `lint` and `tests-integration` jobs. + 1. Add all old version of Go in `go-version:` for `tests-units` job. + 2. Update `go-version:` on current version of Go for `lint` and `tests-integration` jobs. 2. Update build tags by replacing `build go1.xx` on new version. - ### Resolve problems with old version of dependencies -To build `go.mod` compatible for old version use `go mod tidy -compat=1.13` ([docs](https://go.dev/ref/mod#go-mod-tidy)). +To build `go.mod` compatible for old version use +`go mod tidy -compat=1.13` ([docs](https://go.dev/ref/mod#go-mod-tidy)). However, `--compat` doesn't always work correct and we need to set some library versions manually. 1. `go get go.uber.org/multierr@v1.9.0` in [trm](trm), [sql](drivers/sql), [sqlx](drivers/sqlx). 2. `go get github.com/mattn/go-sqlite3@v1.14.14` in [trm](trm), [sql](drivers/sql), [sqlx](drivers/sqlx). -3. `go get github.com/stretchr/testify@v1.8.2` in [trm](trm), [sql](drivers/sql), [sqlx](drivers/sqlx), [goredis8](drivers/goredis8), [mongo](drivers/mongo). -4. `go get github.com/jackc/pgconn@v1.14.2` in [pgxv4](drivers/pgxv4). Golang version was bumped up from 1.12 to 1.17 in pgconn v1.14.3. +3. `go get github.com/stretchr/testify@v1.8.2` + in [trm](trm), [sql](drivers/sql), [sqlx](drivers/sqlx), [goredis8](drivers/goredis8), [mongo](drivers/mongo). +4. `go get github.com/jackc/pgconn@v1.14.2` in [pgxv4](drivers/pgxv4). Golang version was bumped up from 1.12 to 1.17 in + pgconn v1.14.3. 5. `go get golang.org/x/text@v0.13.0` in [pgxv4](drivers/pgxv4). \ No newline at end of file From 946101ade7f399b596a2c744dc5aaf82030f8426 Mon Sep 17 00:00:00 2001 From: hutiquan Date: Wed, 2 Jul 2025 12:26:39 +0800 Subject: [PATCH 4/7] feat:optimize unit test method calls --- drivers/mongov2/example_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/drivers/mongov2/example_test.go b/drivers/mongov2/example_test.go index 3fe0aa8..7bfec42 100644 --- a/drivers/mongov2/example_test.go +++ b/drivers/mongov2/example_test.go @@ -1,11 +1,11 @@ //go:build with_real_db -// +build with_real_db package mongov2_test import ( "context" "fmt" + trmmongo "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2" trmcontext "github.com/avito-tech/go-transaction-manager/trm/v2/context" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" @@ -18,7 +18,7 @@ import ( func Example() { ctx := context.Background() - client, err := mongo.Connect(ctx, options.Client(). + client, err := mongo.Connect(options.Client(). ApplyURI("mongodb://127.0.0.1:27017/?directConnection=true")) checkErr(err) defer client.Disconnect(ctx) From 241cdc44f426f7688d8d7164871f203c5f24f889 Mon Sep 17 00:00:00 2001 From: hutiquan Date: Thu, 17 Jul 2025 10:56:55 +0800 Subject: [PATCH 5/7] feat:optimize import & copy the required internal packages (bsonutil, handshake, ptrutil) from the MongoDB Go Driver --- drivers/mongov2/context.go | 4 +- drivers/mongov2/context_test.go | 6 +- drivers/mongov2/contract.go | 1 + drivers/mongov2/example_test.go | 6 +- drivers/mongov2/internal/bsonutil/bsonutil.go | 62 +++++++++++++++ .../internal/driverutil/description.go | 7 +- .../mongov2/internal/handshake/handshake.go | 13 ++++ .../mongov2/internal/integtest/integtest.go | 5 +- .../mongov2/internal/mtest/global_state.go | 3 +- drivers/mongov2/internal/mtest/mongotest.go | 11 +-- drivers/mongov2/internal/mtest/setup.go | 3 +- drivers/mongov2/internal/ptrutil/int64.go | 39 ++++++++++ .../mongov2/internal/ptrutil/int64_test.go | 76 +++++++++++++++++++ drivers/mongov2/internal/ptrutil/ptr.go | 12 +++ .../serverselector/server_selector_test.go | 9 ++- drivers/mongov2/settings.go | 3 +- drivers/mongov2/settings_test.go | 3 +- drivers/mongov2/transaction.go | 1 + drivers/mongov2/transaction_test.go | 10 ++- 19 files changed, 247 insertions(+), 27 deletions(-) create mode 100644 drivers/mongov2/internal/bsonutil/bsonutil.go create mode 100644 drivers/mongov2/internal/handshake/handshake.go create mode 100644 drivers/mongov2/internal/ptrutil/int64.go create mode 100644 drivers/mongov2/internal/ptrutil/int64_test.go create mode 100644 drivers/mongov2/internal/ptrutil/ptr.go diff --git a/drivers/mongov2/context.go b/drivers/mongov2/context.go index 125886c..bd3a3fa 100644 --- a/drivers/mongov2/context.go +++ b/drivers/mongov2/context.go @@ -2,9 +2,11 @@ package mongov2 import ( "context" - "github.com/avito-tech/go-transaction-manager/trm/v2" + "go.mongodb.org/mongo-driver/v2/mongo" + "github.com/avito-tech/go-transaction-manager/trm/v2" + trmcontext "github.com/avito-tech/go-transaction-manager/trm/v2/context" ) diff --git a/drivers/mongov2/context_test.go b/drivers/mongov2/context_test.go index 822a589..02e85b5 100644 --- a/drivers/mongov2/context_test.go +++ b/drivers/mongov2/context_test.go @@ -4,12 +4,14 @@ package mongov2 import ( "context" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/mtest" "testing" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/mtest" + + "github.com/stretchr/testify/require" + "github.com/avito-tech/go-transaction-manager/trm/v2/manager" "github.com/avito-tech/go-transaction-manager/trm/v2/settings" - "github.com/stretchr/testify/require" ) func TestContext(t *testing.T) { diff --git a/drivers/mongov2/contract.go b/drivers/mongov2/contract.go index 0a51737..f87e58a 100644 --- a/drivers/mongov2/contract.go +++ b/drivers/mongov2/contract.go @@ -2,6 +2,7 @@ package mongov2 import ( "context" + "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/readpref" diff --git a/drivers/mongov2/example_test.go b/drivers/mongov2/example_test.go index 7bfec42..18e4ca4 100644 --- a/drivers/mongov2/example_test.go +++ b/drivers/mongov2/example_test.go @@ -5,12 +5,14 @@ package mongov2_test import ( "context" "fmt" - trmmongo "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2" - trmcontext "github.com/avito-tech/go-transaction-manager/trm/v2/context" + "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" + trmmongo "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2" + trmcontext "github.com/avito-tech/go-transaction-manager/trm/v2/context" + "github.com/avito-tech/go-transaction-manager/trm/v2/manager" ) diff --git a/drivers/mongov2/internal/bsonutil/bsonutil.go b/drivers/mongov2/internal/bsonutil/bsonutil.go new file mode 100644 index 0000000..1eba9c2 --- /dev/null +++ b/drivers/mongov2/internal/bsonutil/bsonutil.go @@ -0,0 +1,62 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsonutil + +import ( + "fmt" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +// StringSliceFromRawValue decodes the provided BSON value into a []string. This function returns an error if the value +// is not an array or any of the elements in the array are not strings. The name parameter is used to add context to +// error messages. +func StringSliceFromRawValue(name string, val bson.RawValue) ([]string, error) { + arr, ok := val.ArrayOK() + if !ok { + return nil, fmt.Errorf("expected '%s' to be an array but it's a BSON %s", name, val.Type) + } + + arrayValues, err := arr.Values() + if err != nil { + return nil, err + } + + strs := make([]string, 0, len(arrayValues)) + for _, arrayVal := range arrayValues { + str, ok := arrayVal.StringValueOK() + if !ok { + return nil, fmt.Errorf("expected '%s' to be an array of strings, but found a BSON %s", name, arrayVal.Type) + } + strs = append(strs, str) + } + return strs, nil +} + +// RawArrayToDocuments converts an array of documents to []bson.Raw. +func RawArrayToDocuments(arr bson.RawArray) []bson.Raw { + values, err := arr.Values() + if err != nil { + panic(fmt.Sprintf("error converting BSON document to values: %v", err)) + } + + out := make([]bson.Raw, len(values)) + for i := range values { + out[i] = values[i].Document() + } + + return out +} + +// RawToInterfaces takes one or many bson.Raw documents and returns them as a []interface{}. +func RawToInterfaces(docs ...bson.Raw) []interface{} { + out := make([]interface{}, len(docs)) + for i := range docs { + out[i] = docs[i] + } + return out +} diff --git a/drivers/mongov2/internal/driverutil/description.go b/drivers/mongov2/internal/driverutil/description.go index df3adc3..2849f7b 100644 --- a/drivers/mongov2/internal/driverutil/description.go +++ b/drivers/mongov2/internal/driverutil/description.go @@ -11,10 +11,11 @@ import ( "fmt" "time" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/bsonutil" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/handshake" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/ptrutil" + "go.mongodb.org/mongo-driver/v2/bson" - "go.mongodb.org/mongo-driver/v2/internal/bsonutil" - "go.mongodb.org/mongo-driver/v2/internal/handshake" - "go.mongodb.org/mongo-driver/v2/internal/ptrutil" "go.mongodb.org/mongo-driver/v2/mongo/address" "go.mongodb.org/mongo-driver/v2/tag" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" diff --git a/drivers/mongov2/internal/handshake/handshake.go b/drivers/mongov2/internal/handshake/handshake.go new file mode 100644 index 0000000..c9537d3 --- /dev/null +++ b/drivers/mongov2/internal/handshake/handshake.go @@ -0,0 +1,13 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package handshake + +// LegacyHello is the legacy version of the hello command. +var LegacyHello = "isMaster" + +// LegacyHelloLowercase is the lowercase, legacy version of the hello command. +var LegacyHelloLowercase = "ismaster" diff --git a/drivers/mongov2/internal/integtest/integtest.go b/drivers/mongov2/internal/integtest/integtest.go index 85a567e..f68de11 100644 --- a/drivers/mongov2/internal/integtest/integtest.go +++ b/drivers/mongov2/internal/integtest/integtest.go @@ -18,8 +18,6 @@ import ( "sync" "testing" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/require" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/serverselector" "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" @@ -27,6 +25,9 @@ import ( "go.mongodb.org/mongo-driver/v2/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology" + + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/require" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/serverselector" ) var connectionString *connstring.ConnString diff --git a/drivers/mongov2/internal/mtest/global_state.go b/drivers/mongov2/internal/mtest/global_state.go index 1f0d0a9..7724c0f 100644 --- a/drivers/mongov2/internal/mtest/global_state.go +++ b/drivers/mongov2/internal/mtest/global_state.go @@ -10,11 +10,12 @@ import ( "context" "fmt" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/failpoint" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology" + + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/failpoint" ) // AuthEnabled returns whether or not the cluster requires auth. diff --git a/drivers/mongov2/internal/mtest/mongotest.go b/drivers/mongov2/internal/mtest/mongotest.go index f510831..c471430 100644 --- a/drivers/mongov2/internal/mtest/mongotest.go +++ b/drivers/mongov2/internal/mtest/mongotest.go @@ -15,11 +15,6 @@ import ( "sync/atomic" "testing" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/assert" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/csfle" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/failpoint" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/mongoutil" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/require" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/mongo" @@ -30,6 +25,12 @@ import ( "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest" + + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/assert" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/csfle" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/failpoint" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/mongoutil" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/require" ) var ( diff --git a/drivers/mongov2/internal/mtest/setup.go b/drivers/mongov2/internal/mtest/setup.go index f2a0d30..87e4bcc 100644 --- a/drivers/mongov2/internal/mtest/setup.go +++ b/drivers/mongov2/internal/mtest/setup.go @@ -16,7 +16,6 @@ import ( "strings" "time" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/integtest" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" @@ -26,6 +25,8 @@ import ( "go.mongodb.org/mongo-driver/v2/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology" + + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/integtest" ) const ( diff --git a/drivers/mongov2/internal/ptrutil/int64.go b/drivers/mongov2/internal/ptrutil/int64.go new file mode 100644 index 0000000..1c3ab57 --- /dev/null +++ b/drivers/mongov2/internal/ptrutil/int64.go @@ -0,0 +1,39 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package ptrutil + +// CompareInt64 is a piecewise function with the following return conditions: +// +// (1) 2, ptr1 != nil AND ptr2 == nil +// (2) 1, *ptr1 > *ptr2 +// (3) 0, ptr1 == ptr2 or *ptr1 == *ptr2 +// (4) -1, *ptr1 < *ptr2 +// (5) -2, ptr1 == nil AND ptr2 != nil +func CompareInt64(ptr1, ptr2 *int64) int { + if ptr1 == ptr2 { + // This will catch the double nil or same-pointer cases. + return 0 + } + + if ptr1 == nil && ptr2 != nil { + return -2 + } + + if ptr1 != nil && ptr2 == nil { + return 2 + } + + if *ptr1 > *ptr2 { + return 1 + } + + if *ptr1 < *ptr2 { + return -1 + } + + return 0 +} diff --git a/drivers/mongov2/internal/ptrutil/int64_test.go b/drivers/mongov2/internal/ptrutil/int64_test.go new file mode 100644 index 0000000..4d90c84 --- /dev/null +++ b/drivers/mongov2/internal/ptrutil/int64_test.go @@ -0,0 +1,76 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package ptrutil + +import ( + "testing" + + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/assert" +) + +func TestCompareInt64(t *testing.T) { + t.Parallel() + + int64ToPtr := func(i64 int64) *int64 { return &i64 } + int64Ptr := int64ToPtr(1) + + tests := []struct { + name string + ptr1, ptr2 *int64 + want int + }{ + { + name: "empty", + want: 0, + }, + { + name: "ptr1 nil", + ptr2: int64ToPtr(1), + want: -2, + }, + { + name: "ptr2 nil", + ptr1: int64ToPtr(1), + want: 2, + }, + { + name: "ptr1 and ptr2 have same value, different address", + ptr1: int64ToPtr(1), + ptr2: int64ToPtr(1), + want: 0, + }, + { + name: "ptr1 and ptr2 have the same address", + ptr1: int64Ptr, + ptr2: int64Ptr, + want: 0, + }, + { + name: "ptr1 GT ptr2", + ptr1: int64ToPtr(1), + ptr2: int64ToPtr(0), + want: 1, + }, + { + name: "ptr1 LT ptr2", + ptr1: int64ToPtr(0), + ptr2: int64ToPtr(1), + want: -1, + }, + } + + for _, test := range tests { + test := test // capture the range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + got := CompareInt64(test.ptr1, test.ptr2) + assert.Equal(t, test.want, got, "compareInt64() = %v, wanted %v", got, test.want) + }) + } +} diff --git a/drivers/mongov2/internal/ptrutil/ptr.go b/drivers/mongov2/internal/ptrutil/ptr.go new file mode 100644 index 0000000..bf64aad --- /dev/null +++ b/drivers/mongov2/internal/ptrutil/ptr.go @@ -0,0 +1,12 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package ptrutil + +// Ptr will return the memory location of the given value. +func Ptr[T any](val T) *T { + return &val +} diff --git a/drivers/mongov2/internal/serverselector/server_selector_test.go b/drivers/mongov2/internal/serverselector/server_selector_test.go index ec26853..c82edb9 100644 --- a/drivers/mongov2/internal/serverselector/server_selector_test.go +++ b/drivers/mongov2/internal/serverselector/server_selector_test.go @@ -13,16 +13,17 @@ import ( "testing" "time" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/assert" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/driverutil" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/require" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/spectest" "github.com/google/go-cmp/cmp" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo/address" "go.mongodb.org/mongo-driver/v2/mongo/readpref" "go.mongodb.org/mongo-driver/v2/tag" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" + + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/assert" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/driverutil" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/require" + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/spectest" ) type lastWriteDate struct { diff --git a/drivers/mongov2/settings.go b/drivers/mongov2/settings.go index e4aaff4..19b9e86 100644 --- a/drivers/mongov2/settings.go +++ b/drivers/mongov2/settings.go @@ -1,8 +1,9 @@ package mongov2 import ( - trm "github.com/avito-tech/go-transaction-manager/trm/v2" "go.mongodb.org/mongo-driver/v2/mongo/options" + + trm "github.com/avito-tech/go-transaction-manager/trm/v2" ) // Opt is a type to configure Settings. diff --git a/drivers/mongov2/settings_test.go b/drivers/mongov2/settings_test.go index 425a19b..2a4dd86 100644 --- a/drivers/mongov2/settings_test.go +++ b/drivers/mongov2/settings_test.go @@ -1,9 +1,10 @@ package mongov2 import ( + "testing" + "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" - "testing" "github.com/stretchr/testify/assert" diff --git a/drivers/mongov2/transaction.go b/drivers/mongov2/transaction.go index 0787e6b..96d5c22 100644 --- a/drivers/mongov2/transaction.go +++ b/drivers/mongov2/transaction.go @@ -3,6 +3,7 @@ package mongov2 import ( "context" + "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" diff --git a/drivers/mongov2/transaction_test.go b/drivers/mongov2/transaction_test.go index b4bd8ce..1917772 100644 --- a/drivers/mongov2/transaction_test.go +++ b/drivers/mongov2/transaction_test.go @@ -5,14 +5,16 @@ package mongov2 import ( "context" "errors" - "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/mtest" + "sync" + "testing" + "time" + "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" - "sync" - "testing" - "time" + + "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/mtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" From 48e2ae80b68053e3b063cfd5e6426626ffa72dc8 Mon Sep 17 00:00:00 2001 From: hutiquan Date: Thu, 17 Jul 2025 11:00:11 +0800 Subject: [PATCH 6/7] feat:golangci issues excludeadd mongov2/internal --- .golangci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.golangci.yml b/.golangci.yml index 822b3a3..b8528fb 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -63,6 +63,7 @@ issues: exclude-dirs: - trm/manager/mock - sql/mock + - drivers/mongov2/internal exclude-use-default: false exclude: - ST1000 # ST1000: at least one file in a package should have a package comment From e9fd7adf1257cbf6348e87970ad86a8fc2bde037 Mon Sep 17 00:00:00 2001 From: hutiquan Date: Wed, 30 Jul 2025 13:37:02 +0800 Subject: [PATCH 7/7] resolve the issue of non-standard package imports. --- drivers/mongov2/context.go | 1 - drivers/mongov2/context_test.go | 3 +-- drivers/mongov2/contract.go | 3 +-- drivers/mongov2/example_test.go | 1 - drivers/mongov2/settings_test.go | 3 +-- 5 files changed, 3 insertions(+), 8 deletions(-) diff --git a/drivers/mongov2/context.go b/drivers/mongov2/context.go index bd3a3fa..d2070df 100644 --- a/drivers/mongov2/context.go +++ b/drivers/mongov2/context.go @@ -6,7 +6,6 @@ import ( "go.mongodb.org/mongo-driver/v2/mongo" "github.com/avito-tech/go-transaction-manager/trm/v2" - trmcontext "github.com/avito-tech/go-transaction-manager/trm/v2/context" ) diff --git a/drivers/mongov2/context_test.go b/drivers/mongov2/context_test.go index 02e85b5..424120c 100644 --- a/drivers/mongov2/context_test.go +++ b/drivers/mongov2/context_test.go @@ -8,10 +8,9 @@ import ( "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2/internal/mtest" - "github.com/stretchr/testify/require" - "github.com/avito-tech/go-transaction-manager/trm/v2/manager" "github.com/avito-tech/go-transaction-manager/trm/v2/settings" + "github.com/stretchr/testify/require" ) func TestContext(t *testing.T) { diff --git a/drivers/mongov2/contract.go b/drivers/mongov2/contract.go index f87e58a..ad7660b 100644 --- a/drivers/mongov2/contract.go +++ b/drivers/mongov2/contract.go @@ -3,10 +3,9 @@ package mongov2 import ( "context" + "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/readpref" - - "go.mongodb.org/mongo-driver/v2/mongo" ) //nolint:interfacebloat diff --git a/drivers/mongov2/example_test.go b/drivers/mongov2/example_test.go index 18e4ca4..578fb6d 100644 --- a/drivers/mongov2/example_test.go +++ b/drivers/mongov2/example_test.go @@ -12,7 +12,6 @@ import ( trmmongo "github.com/avito-tech/go-transaction-manager/drivers/mongov2/v2" trmcontext "github.com/avito-tech/go-transaction-manager/trm/v2/context" - "github.com/avito-tech/go-transaction-manager/trm/v2/manager" ) diff --git a/drivers/mongov2/settings_test.go b/drivers/mongov2/settings_test.go index 2a4dd86..811250e 100644 --- a/drivers/mongov2/settings_test.go +++ b/drivers/mongov2/settings_test.go @@ -6,10 +6,9 @@ import ( "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" - "github.com/stretchr/testify/assert" - "github.com/avito-tech/go-transaction-manager/trm/v2" "github.com/avito-tech/go-transaction-manager/trm/v2/settings" + "github.com/stretchr/testify/assert" ) func TestSettings_EnrichBy(t *testing.T) {