Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions internal/sources/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ package mongodb

import (
"context"
"encoding/json"
"errors"
"fmt"

"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.opentelemetry.io/otel/trace"
Expand Down Expand Up @@ -93,6 +96,201 @@ func (s *Source) MongoClient() *mongo.Client {
return s.Client
}

func parseData(ctx context.Context, cur *mongo.Cursor) ([]any, error) {
var data = []any{}
err := cur.All(ctx, &data)
if err != nil {
return nil, err
}
var final []any
for _, item := range data {
tmp, _ := bson.MarshalExtJSON(item, false, false)
var tmp2 any
err = json.Unmarshal(tmp, &tmp2)
if err != nil {
return nil, err
}
final = append(final, tmp2)
}
return final, err
}

func (s *Source) Aggregate(ctx context.Context, pipelineString string, canonical, readOnly bool, database, collection string) ([]any, error) {
var pipeline = []bson.M{}
err := bson.UnmarshalExtJSON([]byte(pipelineString), canonical, &pipeline)
if err != nil {
return nil, err
}

if readOnly {
//fail if we do a merge or an out
for _, stage := range pipeline {
for key := range stage {
if key == "$merge" || key == "$out" {
return nil, fmt.Errorf("this is not a read-only pipeline: %+v", stage)
}
}
}
}

cur, err := s.MongoClient().Database(database).Collection(collection).Aggregate(ctx, pipeline)
if err != nil {
return nil, err
}
defer cur.Close(ctx)
res, err := parseData(ctx, cur)
if err != nil {
return nil, err
}
if res == nil {
return []any{}, nil
}
return res, err
}

func (s *Source) Find(ctx context.Context, filterString, database, collection string, opts *options.FindOptions) ([]any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}

cur, err := s.MongoClient().Database(database).Collection(collection).Find(ctx, filter, opts)
if err != nil {
return nil, err
}
defer cur.Close(ctx)
return parseData(ctx, cur)
}

func (s *Source) FindOne(ctx context.Context, filterString, database, collection string, opts *options.FindOneOptions) ([]any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}

res := s.MongoClient().Database(database).Collection(collection).FindOne(ctx, filter, opts)
if res.Err() != nil {
return nil, res.Err()
}

var data any
err = res.Decode(&data)
if err != nil {
return nil, err
}

var final []any
tmp, _ := bson.MarshalExtJSON(data, false, false)
var tmp2 any
err = json.Unmarshal(tmp, &tmp2)
if err != nil {
return nil, err
}
final = append(final, tmp2)

return final, err
}

func (s *Source) InsertMany(ctx context.Context, jsonData string, canonical bool, database, collection string) ([]any, error) {
var data = []any{}
err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data)
if err != nil {
return nil, err
}

res, err := s.MongoClient().Database(database).Collection(collection).InsertMany(ctx, data, options.InsertMany())
if err != nil {
return nil, err
}
return res.InsertedIDs, nil
}

func (s *Source) InsertOne(ctx context.Context, jsonData string, canonical bool, database, collection string) (any, error) {
var data any
err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data)
if err != nil {
return nil, err
}

res, err := s.MongoClient().Database(database).Collection(collection).InsertOne(ctx, data, options.InsertOne())
if err != nil {
return nil, err
}
return res.InsertedID, nil
}

func (s *Source) UpdateMany(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) ([]any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), canonical, &filter)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
}
var update = bson.D{}
err = bson.UnmarshalExtJSON([]byte(updateString), false, &update)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
}

res, err := s.MongoClient().Database(database).Collection(collection).UpdateMany(ctx, filter, update, options.Update().SetUpsert(upsert))
if err != nil {
return nil, fmt.Errorf("error updating collection: %w", err)
}
return []any{res.ModifiedCount, res.UpsertedCount, res.MatchedCount}, nil
}

func (s *Source) UpdateOne(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) (any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
}
var update = bson.D{}
err = bson.UnmarshalExtJSON([]byte(updateString), canonical, &update)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
}

res, err := s.MongoClient().Database(database).Collection(collection).UpdateOne(ctx, filter, update, options.Update().SetUpsert(upsert))
if err != nil {
return nil, fmt.Errorf("error updating collection: %w", err)
}
return res.ModifiedCount, nil
}

func (s *Source) DeleteMany(ctx context.Context, filterString, database, collection string) (any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}

res, err := s.MongoClient().Database(database).Collection(collection).DeleteMany(ctx, filter, options.Delete())
if err != nil {
return nil, err
}

if res.DeletedCount == 0 {
return nil, errors.New("no document found")
}
return res.DeletedCount, nil
}

func (s *Source) DeleteOne(ctx context.Context, filterString, database, collection string) (any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}

res, err := s.MongoClient().Database(database).Collection(collection).DeleteOne(ctx, filter, options.Delete())
if err != nil {
return nil, err
}
return res.DeletedCount, nil
}

func initMongoDBClient(ctx context.Context, tracer trace.Tracer, name, uri string) (*mongo.Client, error) {
// Start a tracing span
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
Expand Down
51 changes: 2 additions & 49 deletions internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@ package mongodbaggregate

import (
"context"
"encoding/json"
"fmt"
"slices"

"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"

"github.com/googleapis/genai-toolbox/internal/sources"
Expand All @@ -47,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T

type compatibleSource interface {
MongoClient() *mongo.Client
Aggregate(context.Context, string, bool, bool, string, string) ([]any, error)
}

type Config struct {
Expand Down Expand Up @@ -110,57 +109,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}

paramsMap := params.AsMap()

pipelineString, err := parameters.PopulateTemplateWithJSON("MongoDBAggregatePipeline", t.PipelinePayload, paramsMap)
if err != nil {
return nil, fmt.Errorf("error populating pipeline: %s", err)
}

var pipeline = []bson.M{}
err = bson.UnmarshalExtJSON([]byte(pipelineString), t.Canonical, &pipeline)
if err != nil {
return nil, err
}

if t.ReadOnly {
//fail if we do a merge or an out
for _, stage := range pipeline {
for key := range stage {
if key == "$merge" || key == "$out" {
return nil, fmt.Errorf("this is not a read-only pipeline: %+v", stage)
}
}
}
}

cur, err := source.MongoClient().Database(t.Database).Collection(t.Collection).Aggregate(ctx, pipeline)
if err != nil {
return nil, err
}
defer cur.Close(ctx)

var data = []any{}
err = cur.All(ctx, &data)
if err != nil {
return nil, err
}

if len(data) == 0 {
return []any{}, nil
}

var final []any
for _, item := range data {
tmp, _ := bson.MarshalExtJSON(item, false, false)
var tmp2 any
err = json.Unmarshal(tmp, &tmp2)
if err != nil {
return nil, err
}
final = append(final, tmp2)
}

return final, err
return source.Aggregate(ctx, pipelineString, t.Canonical, t.ReadOnly, t.Database, t.Collection)
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
Expand Down
26 changes: 2 additions & 24 deletions internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,13 @@ package mongodbdeletemany

import (
"context"
"errors"
"fmt"
"slices"

"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"

"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
Expand All @@ -48,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T

type compatibleSource interface {
MongoClient() *mongo.Client
DeleteMany(context.Context, string, string, string) (any, error)
}

type Config struct {
Expand Down Expand Up @@ -115,31 +113,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}

paramsMap := params.AsMap()

filterString, err := parameters.PopulateTemplateWithJSON("MongoDBDeleteManyFilter", t.FilterPayload, paramsMap)
if err != nil {
return nil, fmt.Errorf("error populating filter: %s", err)
}

opts := options.Delete()

var filter = bson.D{}
err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}

res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).DeleteMany(ctx, filter, opts)
if err != nil {
return nil, err
}

if res.DeletedCount == 0 {
return nil, errors.New("no document found")
}

// not much to return actually
return res.DeletedCount, nil
return source.DeleteMany(ctx, filterString, t.Database, t.Collection)
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
Expand Down
20 changes: 2 additions & 18 deletions internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"

"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
Expand All @@ -47,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T

type compatibleSource interface {
MongoClient() *mongo.Client
DeleteOne(context.Context, string, string, string) (any, error)
}

type Config struct {
Expand Down Expand Up @@ -119,22 +118,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, fmt.Errorf("error populating filter: %s", err)
}

opts := options.Delete()

var filter = bson.D{}
err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}

res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).DeleteOne(ctx, filter, opts)
if err != nil {
return nil, err
}

// do not return an error when the count is 0, to mirror the delete many call result
return res.DeletedCount, nil
return source.DeleteOne(ctx, filterString, t.Database, t.Collection)
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
Expand Down
Loading
Loading