From 34d7df73ee487e05f599137fc7d862baee7c15af Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Wed, 7 Jan 2026 22:38:20 -0800 Subject: [PATCH] refactor(sources/mongodb): move source implementation in Invoke() function into Source --- internal/sources/mongodb/mongodb.go | 198 ++++++++++++++++++ .../mongodbaggregate/mongodbaggregate.go | 51 +---- .../mongodbdeletemany/mongodbdeletemany.go | 26 +-- .../mongodbdeleteone/mongodbdeleteone.go | 20 +- .../tools/mongodb/mongodbfind/mongodbfind.go | 37 +--- .../mongodb/mongodbfindone/mongodbfindone.go | 33 +-- .../mongodbinsertmany/mongodbinsertmany.go | 17 +- .../mongodbinsertone/mongodbinsertone.go | 18 +- .../mongodbupdatemany/mongodbupdatemany.go | 25 +-- .../mongodbupdateone/mongodbupdateone.go | 25 +-- 10 files changed, 216 insertions(+), 234 deletions(-) diff --git a/internal/sources/mongodb/mongodb.go b/internal/sources/mongodb/mongodb.go index 035788f35aa..533e5005d58 100644 --- a/internal/sources/mongodb/mongodb.go +++ b/internal/sources/mongodb/mongodb.go @@ -16,11 +16,14 @@ package mongodb import ( "context" + "encoding/json" + "errors" "fmt" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" + "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.opentelemetry.io/otel/trace" @@ -93,6 +96,201 @@ func (s *Source) MongoClient() *mongo.Client { return s.Client } +func parseData(ctx context.Context, cur *mongo.Cursor) ([]any, error) { + var data = []any{} + err := cur.All(ctx, &data) + if err != nil { + return nil, err + } + var final []any + for _, item := range data { + tmp, _ := bson.MarshalExtJSON(item, false, false) + var tmp2 any + err = json.Unmarshal(tmp, &tmp2) + if err != nil { + return nil, err + } + final = append(final, tmp2) + } + return final, err +} + +func (s *Source) Aggregate(ctx context.Context, pipelineString string, canonical, readOnly bool, database, collection string) ([]any, error) { + var pipeline = []bson.M{} + err := bson.UnmarshalExtJSON([]byte(pipelineString), canonical, &pipeline) + if err != nil { + return nil, err + } + + if readOnly { + //fail if we do a merge or an out + for _, stage := range pipeline { + for key := range stage { + if key == "$merge" || key == "$out" { + return nil, fmt.Errorf("this is not a read-only pipeline: %+v", stage) + } + } + } + } + + cur, err := s.MongoClient().Database(database).Collection(collection).Aggregate(ctx, pipeline) + if err != nil { + return nil, err + } + defer cur.Close(ctx) + res, err := parseData(ctx, cur) + if err != nil { + return nil, err + } + if res == nil { + return []any{}, nil + } + return res, err +} + +func (s *Source) Find(ctx context.Context, filterString, database, collection string, opts *options.FindOptions) ([]any, error) { + var filter = bson.D{} + err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter) + if err != nil { + return nil, err + } + + cur, err := s.MongoClient().Database(database).Collection(collection).Find(ctx, filter, opts) + if err != nil { + return nil, err + } + defer cur.Close(ctx) + return parseData(ctx, cur) +} + +func (s *Source) FindOne(ctx context.Context, filterString, database, collection string, opts *options.FindOneOptions) ([]any, error) { + var filter = bson.D{} + err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter) + if err != nil { + return nil, err + } + + res := s.MongoClient().Database(database).Collection(collection).FindOne(ctx, filter, opts) + if res.Err() != nil { + return nil, res.Err() + } + + var data any + err = res.Decode(&data) + if err != nil { + return nil, err + } + + var final []any + tmp, _ := bson.MarshalExtJSON(data, false, false) + var tmp2 any + err = json.Unmarshal(tmp, &tmp2) + if err != nil { + return nil, err + } + final = append(final, tmp2) + + return final, err +} + +func (s *Source) InsertMany(ctx context.Context, jsonData string, canonical bool, database, collection string) ([]any, error) { + var data = []any{} + err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data) + if err != nil { + return nil, err + } + + res, err := s.MongoClient().Database(database).Collection(collection).InsertMany(ctx, data, options.InsertMany()) + if err != nil { + return nil, err + } + return res.InsertedIDs, nil +} + +func (s *Source) InsertOne(ctx context.Context, jsonData string, canonical bool, database, collection string) (any, error) { + var data any + err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data) + if err != nil { + return nil, err + } + + res, err := s.MongoClient().Database(database).Collection(collection).InsertOne(ctx, data, options.InsertOne()) + if err != nil { + return nil, err + } + return res.InsertedID, nil +} + +func (s *Source) UpdateMany(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) ([]any, error) { + var filter = bson.D{} + err := bson.UnmarshalExtJSON([]byte(filterString), canonical, &filter) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal filter string: %w", err) + } + var update = bson.D{} + err = bson.UnmarshalExtJSON([]byte(updateString), false, &update) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal update string: %w", err) + } + + res, err := s.MongoClient().Database(database).Collection(collection).UpdateMany(ctx, filter, update, options.Update().SetUpsert(upsert)) + if err != nil { + return nil, fmt.Errorf("error updating collection: %w", err) + } + return []any{res.ModifiedCount, res.UpsertedCount, res.MatchedCount}, nil +} + +func (s *Source) UpdateOne(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) (any, error) { + var filter = bson.D{} + err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal filter string: %w", err) + } + var update = bson.D{} + err = bson.UnmarshalExtJSON([]byte(updateString), canonical, &update) + if err != nil { + return nil, fmt.Errorf("unable to unmarshal update string: %w", err) + } + + res, err := s.MongoClient().Database(database).Collection(collection).UpdateOne(ctx, filter, update, options.Update().SetUpsert(upsert)) + if err != nil { + return nil, fmt.Errorf("error updating collection: %w", err) + } + return res.ModifiedCount, nil +} + +func (s *Source) DeleteMany(ctx context.Context, filterString, database, collection string) (any, error) { + var filter = bson.D{} + err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter) + if err != nil { + return nil, err + } + + res, err := s.MongoClient().Database(database).Collection(collection).DeleteMany(ctx, filter, options.Delete()) + if err != nil { + return nil, err + } + + if res.DeletedCount == 0 { + return nil, errors.New("no document found") + } + return res.DeletedCount, nil +} + +func (s *Source) DeleteOne(ctx context.Context, filterString, database, collection string) (any, error) { + var filter = bson.D{} + err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter) + if err != nil { + return nil, err + } + + res, err := s.MongoClient().Database(database).Collection(collection).DeleteOne(ctx, filter, options.Delete()) + if err != nil { + return nil, err + } + return res.DeletedCount, nil +} + func initMongoDBClient(ctx context.Context, tracer trace.Tracer, name, uri string) (*mongo.Client, error) { // Start a tracing span ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go b/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go index 2fa313b8833..00bf5641aa4 100644 --- a/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go +++ b/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go @@ -15,14 +15,12 @@ package mongodbaggregate import ( "context" - "encoding/json" "fmt" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "github.com/googleapis/genai-toolbox/internal/sources" @@ -47,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MongoClient() *mongo.Client + Aggregate(context.Context, string, bool, bool, string, string) ([]any, error) } type Config struct { @@ -110,57 +109,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } paramsMap := params.AsMap() - pipelineString, err := parameters.PopulateTemplateWithJSON("MongoDBAggregatePipeline", t.PipelinePayload, paramsMap) if err != nil { return nil, fmt.Errorf("error populating pipeline: %s", err) } - - var pipeline = []bson.M{} - err = bson.UnmarshalExtJSON([]byte(pipelineString), t.Canonical, &pipeline) - if err != nil { - return nil, err - } - - if t.ReadOnly { - //fail if we do a merge or an out - for _, stage := range pipeline { - for key := range stage { - if key == "$merge" || key == "$out" { - return nil, fmt.Errorf("this is not a read-only pipeline: %+v", stage) - } - } - } - } - - cur, err := source.MongoClient().Database(t.Database).Collection(t.Collection).Aggregate(ctx, pipeline) - if err != nil { - return nil, err - } - defer cur.Close(ctx) - - var data = []any{} - err = cur.All(ctx, &data) - if err != nil { - return nil, err - } - - if len(data) == 0 { - return []any{}, nil - } - - var final []any - for _, item := range data { - tmp, _ := bson.MarshalExtJSON(item, false, false) - var tmp2 any - err = json.Unmarshal(tmp, &tmp2) - if err != nil { - return nil, err - } - final = append(final, tmp2) - } - - return final, err + return source.Aggregate(ctx, pipelineString, t.Canonical, t.ReadOnly, t.Database, t.Collection) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go b/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go index ab62bdb2f5b..29a2a334950 100644 --- a/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go +++ b/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go @@ -15,16 +15,13 @@ package mongodbdeletemany import ( "context" - "errors" "fmt" "slices" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" @@ -48,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MongoClient() *mongo.Client + DeleteMany(context.Context, string, string, string) (any, error) } type Config struct { @@ -115,31 +113,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } paramsMap := params.AsMap() - filterString, err := parameters.PopulateTemplateWithJSON("MongoDBDeleteManyFilter", t.FilterPayload, paramsMap) if err != nil { return nil, fmt.Errorf("error populating filter: %s", err) } - - opts := options.Delete() - - var filter = bson.D{} - err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter) - if err != nil { - return nil, err - } - - res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).DeleteMany(ctx, filter, opts) - if err != nil { - return nil, err - } - - if res.DeletedCount == 0 { - return nil, errors.New("no document found") - } - - // not much to return actually - return res.DeletedCount, nil + return source.DeleteMany(ctx, filterString, t.Database, t.Collection) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go b/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go index 0e494f73c7e..2d761d83ede 100644 --- a/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go +++ b/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go @@ -21,9 +21,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" @@ -47,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MongoClient() *mongo.Client + DeleteOne(context.Context, string, string, string) (any, error) } type Config struct { @@ -119,22 +118,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("error populating filter: %s", err) } - - opts := options.Delete() - - var filter = bson.D{} - err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter) - if err != nil { - return nil, err - } - - res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).DeleteOne(ctx, filter, opts) - if err != nil { - return nil, err - } - - // do not return an error when the count is 0, to mirror the delete many call result - return res.DeletedCount, nil + return source.DeleteOne(ctx, filterString, t.Database, t.Collection) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbfind/mongodbfind.go b/internal/tools/mongodb/mongodbfind/mongodbfind.go index e447bb15a0a..12ccd846361 100644 --- a/internal/tools/mongodb/mongodbfind/mongodbfind.go +++ b/internal/tools/mongodb/mongodbfind/mongodbfind.go @@ -15,7 +15,6 @@ package mongodbfind import ( "context" - "encoding/json" "fmt" "slices" @@ -49,6 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MongoClient() *mongo.Client + Find(context.Context, string, string, string, *options.FindOptions) ([]any, error) } type Config struct { @@ -164,48 +164,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } paramsMap := params.AsMap() - filterString, err := parameters.PopulateTemplateWithJSON("MongoDBFindFilterString", t.FilterPayload, paramsMap) - if err != nil { return nil, fmt.Errorf("error populating filter: %s", err) } - opts, err := getOptions(ctx, t.SortParams, t.ProjectPayload, t.Limit, paramsMap) if err != nil { return nil, fmt.Errorf("error populating options: %s", err) } - - var filter = bson.D{} - err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter) - if err != nil { - return nil, err - } - - cur, err := source.MongoClient().Database(t.Database).Collection(t.Collection).Find(ctx, filter, opts) - if err != nil { - return nil, err - } - defer cur.Close(ctx) - - var data = []any{} - err = cur.All(context.TODO(), &data) - if err != nil { - return nil, err - } - - var final []any - for _, item := range data { - tmp, _ := bson.MarshalExtJSON(item, false, false) - var tmp2 any - err = json.Unmarshal(tmp, &tmp2) - if err != nil { - return nil, err - } - final = append(final, tmp2) - } - - return final, err + return source.Find(ctx, filterString, t.Database, t.Collection, opts) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbfindone/mongodbfindone.go b/internal/tools/mongodb/mongodbfindone/mongodbfindone.go index 6fa537a635a..e9f1555c669 100644 --- a/internal/tools/mongodb/mongodbfindone/mongodbfindone.go +++ b/internal/tools/mongodb/mongodbfindone/mongodbfindone.go @@ -15,7 +15,6 @@ package mongodbfindone import ( "context" - "encoding/json" "fmt" "slices" @@ -48,6 +47,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MongoClient() *mongo.Client + FindOne(context.Context, string, string, string, *options.FindOneOptions) ([]any, error) } type Config struct { @@ -117,9 +117,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } paramsMap := params.AsMap() - filterString, err := parameters.PopulateTemplateWithJSON("MongoDBFindOneFilterString", t.FilterPayload, paramsMap) - if err != nil { return nil, fmt.Errorf("error populating filter: %s", err) } @@ -137,34 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } opts = opts.SetProjection(projection) } - - var filter = bson.D{} - err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter) - if err != nil { - return nil, err - } - - res := source.MongoClient().Database(t.Database).Collection(t.Collection).FindOne(ctx, filter, opts) - if res.Err() != nil { - return nil, res.Err() - } - - var data any - err = res.Decode(&data) - if err != nil { - return nil, err - } - - var final []any - tmp, _ := bson.MarshalExtJSON(data, false, false) - var tmp2 any - err = json.Unmarshal(tmp, &tmp2) - if err != nil { - return nil, err - } - final = append(final, tmp2) - - return final, err + return source.FindOne(ctx, filterString, t.Database, t.Collection, opts) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go b/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go index 0cbaca3c0db..0dc11a230b6 100644 --- a/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go +++ b/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go @@ -23,9 +23,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" ) const kind string = "mongodb-insert-many" @@ -48,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MongoClient() *mongo.Client + InsertMany(context.Context, string, bool, string, string) ([]any, error) } type Config struct { @@ -117,19 +116,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if !ok { return nil, errors.New("no input found") } - - var data = []any{} - err = bson.UnmarshalExtJSON([]byte(jsonData), t.Canonical, &data) - if err != nil { - return nil, err - } - - res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).InsertMany(ctx, data, options.InsertMany()) - if err != nil { - return nil, err - } - - return res.InsertedIDs, nil + return source.InsertMany(ctx, jsonData, t.Canonical, t.Database, t.Collection) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go b/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go index 23e2928a871..a93589c053b 100644 --- a/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go +++ b/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go @@ -23,9 +23,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" ) const kind string = "mongodb-insert-one" @@ -48,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MongoClient() *mongo.Client + InsertOne(context.Context, string, bool, string, string) (any, error) } type Config struct { @@ -107,7 +106,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, err } - if len(params) == 0 { return nil, errors.New("no input found") } @@ -116,19 +114,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if !ok { return nil, errors.New("no input found") } - - var data any - err = bson.UnmarshalExtJSON([]byte(jsonData), t.Canonical, &data) - if err != nil { - return nil, err - } - - res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).InsertOne(ctx, data, options.InsertOne()) - if err != nil { - return nil, err - } - - return res.InsertedID, nil + return source.InsertOne(ctx, jsonData, t.Canonical, t.Database, t.Collection) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go b/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go index 9dcadc66eff..b80bc4972e4 100644 --- a/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go +++ b/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go @@ -23,9 +23,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" ) const kind string = "mongodb-update-many" @@ -46,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MongoClient() *mongo.Client + UpdateMany(context.Context, string, bool, string, string, string, bool) ([]any, error) } type Config struct { @@ -117,35 +116,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } paramsMap := params.AsMap() - filterString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateManyFilter", t.FilterPayload, paramsMap) if err != nil { return nil, fmt.Errorf("error populating filter: %s", err) } - - var filter = bson.D{} - err = bson.UnmarshalExtJSON([]byte(filterString), t.Canonical, &filter) - if err != nil { - return nil, fmt.Errorf("unable to unmarshal filter string: %w", err) - } - updateString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateMany", t.UpdatePayload, paramsMap) if err != nil { return nil, fmt.Errorf("unable to get update: %w", err) } - - var update = bson.D{} - err = bson.UnmarshalExtJSON([]byte(updateString), false, &update) - if err != nil { - return nil, fmt.Errorf("unable to unmarshal update string: %w", err) - } - - res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).UpdateMany(ctx, filter, update, options.Update().SetUpsert(t.Upsert)) - if err != nil { - return nil, fmt.Errorf("error updating collection: %w", err) - } - - return []any{res.ModifiedCount, res.UpsertedCount, res.MatchedCount}, nil + return source.UpdateMany(ctx, filterString, t.Canonical, updateString, t.Database, t.Collection, t.Upsert) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go b/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go index 11bbe2ac164..d3236992e84 100644 --- a/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go +++ b/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go @@ -23,9 +23,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" ) const kind string = "mongodb-update-one" @@ -46,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MongoClient() *mongo.Client + UpdateOne(context.Context, string, bool, string, string, string, bool) (any, error) } type Config struct { @@ -118,35 +117,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } paramsMap := params.AsMap() - filterString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateOneFilter", t.FilterPayload, paramsMap) if err != nil { return nil, fmt.Errorf("error populating filter: %s", err) } - - var filter = bson.D{} - err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter) - if err != nil { - return nil, fmt.Errorf("unable to unmarshal filter string: %w", err) - } - updateString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateOne", t.UpdatePayload, paramsMap) if err != nil { return nil, fmt.Errorf("unable to get update: %w", err) } - - var update = bson.D{} - err = bson.UnmarshalExtJSON([]byte(updateString), t.Canonical, &update) - if err != nil { - return nil, fmt.Errorf("unable to unmarshal update string: %w", err) - } - - res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).UpdateOne(ctx, filter, update, options.Update().SetUpsert(t.Upsert)) - if err != nil { - return nil, fmt.Errorf("error updating collection: %w", err) - } - - return res.ModifiedCount, nil + return source.UpdateOne(ctx, filterString, t.Canonical, updateString, t.Database, t.Collection, t.Upsert) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {