From f00b4a740e6d1aec0d60a5cce71349272dbdb299 Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Thu, 4 Jul 2024 16:31:36 +0200 Subject: [PATCH] contrib/gorm.io/gorm.v1: propagate parent span (#2759) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Michel Chretien Co-authored-by: Michel Chrétien --- contrib/gorm.io/gorm.v1/example_test.go | 26 +++++ contrib/gorm.io/gorm.v1/gorm.go | 96 +++++++++++------- contrib/gorm.io/gorm.v1/gorm_test.go | 124 ++++++++++++++++++++++-- 3 files changed, 201 insertions(+), 45 deletions(-) diff --git a/contrib/gorm.io/gorm.v1/example_test.go b/contrib/gorm.io/gorm.v1/example_test.go index ec45b76a00..2b75405503 100644 --- a/contrib/gorm.io/gorm.v1/example_test.go +++ b/contrib/gorm.io/gorm.v1/example_test.go @@ -7,6 +7,7 @@ package gorm_test import ( "context" + "errors" "log" sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql" @@ -41,6 +42,31 @@ func ExampleOpen() { db.Where("name = ?", "jinzhu").First(&user) } +// ExampleNewTracePlugin illustrates how to trace gorm using the gorm.Plugin api. +func ExampleNewTracePlugin() { + // Register augments the provided driver with tracing, enabling it to be loaded by gorm.Open and the gormtrace.TracePlugin. + sqltrace.Register("pgx", &stdlib.Driver{}, sqltrace.WithServiceName("my-service")) + sqlDb, err := sqltrace.Open("pgx", "postgres://pqgotest:password@localhost/pqgotest?sslmode=disable") + if err != nil { + log.Fatal(err) + } + db, err := gorm.Open(postgres.New(postgres.Config{Conn: sqlDb}), &gorm.Config{}) + if err != nil { + log.Fatal(err) + } + var user User + + errCheck := gormtrace.WithErrorCheck(func(err error) bool { + return !errors.Is(err, gorm.ErrRecordNotFound) + }) + if err := db.Use(gormtrace.NewTracePlugin(errCheck)); err != nil { + log.Fatal(err) + } + + // All calls through gorm.DB are now traced. + db.Where("name = ?", "jinzhu").First(&user) +} + func ExampleContext() { // Register augments the provided driver with tracing, enabling it to be loaded by gormtrace.Open. sqltrace.Register("pgx", &stdlib.Driver{}, sqltrace.WithServiceName("my-service")) diff --git a/contrib/gorm.io/gorm.v1/gorm.go b/contrib/gorm.io/gorm.v1/gorm.go index 0444065114..b87b8b79a0 100644 --- a/contrib/gorm.io/gorm.v1/gorm.go +++ b/contrib/gorm.io/gorm.v1/gorm.go @@ -7,9 +7,7 @@ package gorm import ( - "context" "math" - "time" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" @@ -33,6 +31,26 @@ const ( gormSpanStartTimeKey = key("dd-trace-go:span") ) +type tracePlugin struct { + options []Option +} + +// NewTracePlugin returns a new gorm.Plugin that enhances the underlying *gorm.DB with tracing. +func NewTracePlugin(opts ...Option) gorm.Plugin { + return tracePlugin{ + options: opts, + } +} + +func (tracePlugin) Name() string { + return "DDTracePlugin" +} + +func (g tracePlugin) Initialize(db *gorm.DB) error { + _, err := withCallbacks(db, g.options...) + return err +} + // Open opens a new (traced) database connection. The used driver must be formerly registered // using (gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql).Register. func Open(dialector gorm.Dialector, cfg *gorm.Config, opts ...Option) (*gorm.DB, error) { @@ -57,86 +75,80 @@ func withCallbacks(db *gorm.DB, opts ...Option) (*gorm.DB, error) { } log.Debug("Registering Callbacks: %#v", cfg) - afterFunc := func(operationName string) func(*gorm.DB) { + afterFunc := func() func(*gorm.DB) { + return func(db *gorm.DB) { + after(db, cfg) + } + } + + beforeFunc := func(operationName string) func(*gorm.DB) { return func(db *gorm.DB) { - after(db, operationName, cfg) + before(db, operationName, cfg) } } cb := db.Callback() - err := cb.Create().Before("gorm:create").Register("dd-trace-go:before_create", before) + err := cb.Create().Before("gorm:create").Register("dd-trace-go:before_create", beforeFunc("gorm.create")) if err != nil { return db, err } - err = cb.Create().After("gorm:create").Register("dd-trace-go:after_create", afterFunc("gorm.create")) + err = cb.Create().After("gorm:create").Register("dd-trace-go:after_create", afterFunc()) if err != nil { return db, err } - err = cb.Update().Before("gorm:update").Register("dd-trace-go:before_update", before) + err = cb.Update().Before("gorm:update").Register("dd-trace-go:before_update", beforeFunc("gorm.update")) if err != nil { return db, err } - err = cb.Update().After("gorm:update").Register("dd-trace-go:after_update", afterFunc("gorm.update")) + err = cb.Update().After("gorm:update").Register("dd-trace-go:after_update", afterFunc()) if err != nil { return db, err } - err = cb.Delete().Before("gorm:delete").Register("dd-trace-go:before_delete", before) + err = cb.Delete().Before("gorm:delete").Register("dd-trace-go:before_delete", beforeFunc("gorm.delete")) if err != nil { return db, err } - err = cb.Delete().After("gorm:delete").Register("dd-trace-go:after_delete", afterFunc("gorm.delete")) + err = cb.Delete().After("gorm:delete").Register("dd-trace-go:after_delete", afterFunc()) if err != nil { return db, err } - err = cb.Query().Before("gorm:query").Register("dd-trace-go:before_query", before) + err = cb.Query().Before("gorm:query").Register("dd-trace-go:before_query", beforeFunc("gorm.query")) if err != nil { return db, err } - err = cb.Query().After("gorm:query").Register("dd-trace-go:after_query", afterFunc("gorm.query")) + err = cb.Query().After("gorm:query").Register("dd-trace-go:after_query", afterFunc()) if err != nil { return db, err } - err = cb.Row().Before("gorm:query").Register("dd-trace-go:before_row_query", before) + err = cb.Row().Before("gorm:row").Register("dd-trace-go:before_row_query", beforeFunc("gorm.row_query")) if err != nil { return db, err } - err = cb.Row().After("gorm:query").Register("dd-trace-go:after_row_query", afterFunc("gorm.row_query")) + err = cb.Row().After("gorm:row").Register("dd-trace-go:after_row_query", afterFunc()) if err != nil { return db, err } - err = cb.Raw().Before("gorm:query").Register("dd-trace-go:before_raw_query", before) + err = cb.Raw().Before("gorm:raw").Register("dd-trace-go:before_raw_query", beforeFunc("gorm.raw_query")) if err != nil { return db, err } - err = cb.Raw().After("gorm:query").Register("dd-trace-go:after_raw_query", afterFunc("gorm.raw_query")) + err = cb.Raw().After("gorm:raw").Register("dd-trace-go:after_raw_query", afterFunc()) if err != nil { return db, err } return db, nil } -func before(scope *gorm.DB) { - if scope.Statement != nil && scope.Statement.Context != nil { - scope.Statement.Context = context.WithValue(scope.Statement.Context, gormSpanStartTimeKey, time.Now()) - } -} - -func after(db *gorm.DB, operationName string, cfg *config) { +func before(db *gorm.DB, operationName string, cfg *config) { if db.Statement == nil || db.Statement.Context == nil { return } - - ctx := db.Statement.Context - t, ok := ctx.Value(gormSpanStartTimeKey).(time.Time) - if !ok { + if db.Config == nil || db.Config.DryRun { return } - opts := []ddtrace.StartSpanOption{ - tracer.StartTime(t), tracer.ServiceName(cfg.serviceName), tracer.SpanType(ext.SpanTypeSQL), - tracer.ResourceName(db.Statement.SQL.String()), tracer.Tag(ext.Component, componentName), } if !math.IsNaN(cfg.analyticsRate) { @@ -148,10 +160,24 @@ func after(db *gorm.DB, operationName string, cfg *config) { } } - span, _ := tracer.StartSpanFromContext(ctx, operationName, opts...) - var dbErr error - if cfg.errCheck(db.Error) { - dbErr = db.Error + _, ctx := tracer.StartSpanFromContext(db.Statement.Context, operationName, opts...) + db.Statement.Context = ctx +} + +func after(db *gorm.DB, cfg *config) { + if db.Statement == nil || db.Statement.Context == nil { + return + } + if db.Config == nil || db.Config.DryRun { + return + } + span, ok := tracer.SpanFromContext(db.Statement.Context) + if ok { + var dbErr error + if cfg.errCheck(db.Error) { + dbErr = db.Error + } + span.SetTag(ext.ResourceName, db.Statement.SQL.String()) + span.Finish(tracer.WithError(dbErr)) } - span.Finish(tracer.WithError(dbErr)) } diff --git a/contrib/gorm.io/gorm.v1/gorm_test.go b/contrib/gorm.io/gorm.v1/gorm_test.go index cb1a11f95b..8dac64bfcf 100644 --- a/contrib/gorm.io/gorm.v1/gorm_test.go +++ b/contrib/gorm.io/gorm.v1/gorm_test.go @@ -24,10 +24,12 @@ import ( _ "github.com/lib/pq" mssql "github.com/microsoft/go-mssqldb" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" mysqlgorm "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlserver" "gorm.io/gorm" + "gorm.io/gorm/utils/tests" ) // tableName holds the SQL table that these tests will be run against. It must be unique cross-repo. @@ -172,10 +174,6 @@ type Product struct { } func TestCallbacks(t *testing.T) { - a := assert.New(t) - mt := mocktracer.Start() - defer mt.Stop() - sqltrace.Register("pgx", &stdlib.Driver{}) sqlDb, err := sqltrace.Open("pgx", pgConnString) if err != nil { @@ -193,12 +191,15 @@ func TestCallbacks(t *testing.T) { } t.Run("create", func(t *testing.T) { + a := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request", tracer.ServiceName("fake-http-server"), tracer.SpanType(ext.SpanTypeWeb), ) - db = db.WithContext(ctx) + db := db.WithContext(ctx) var queryText string db.Callback().Create().After("testing").Register("query text", func(d *gorm.DB) { queryText = d.Statement.SQL.String() @@ -215,15 +216,26 @@ func TestCallbacks(t *testing.T) { a.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType)) a.Equal(queryText, span.Tag(ext.ResourceName)) a.Equal("gorm.io/gorm.v1", span.Tag(ext.Component)) + a.Equal(parentSpan.Context().SpanID(), span.ParentID()) + + for _, s := range spans { + if s.Tag(ext.Component) == "jackc/pgx.v5" { + // The underlying driver should receive the gorm span + a.Equal(span.SpanID(), s.ParentID()) + } + } }) t.Run("query", func(t *testing.T) { + a := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request", tracer.ServiceName("fake-http-server"), tracer.SpanType(ext.SpanTypeWeb), ) - db = db.WithContext(ctx) + db := db.WithContext(ctx) var queryText string db.Callback().Query().After("testing").Register("query text", func(d *gorm.DB) { queryText = d.Statement.SQL.String() @@ -241,15 +253,46 @@ func TestCallbacks(t *testing.T) { a.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType)) a.Equal(queryText, span.Tag(ext.ResourceName)) a.Equal("gorm.io/gorm.v1", span.Tag(ext.Component)) + a.Equal(parentSpan.Context().SpanID(), span.ParentID()) + + for _, s := range spans { + if s.Tag(ext.Component) == "jackc/pgx.v5" { + // The underlying driver should receive the gorm span + a.Equal(span.SpanID(), s.ParentID()) + } + } + }) + + t.Run("dry_run", func(t *testing.T) { + a := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() + parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request", + tracer.ServiceName("fake-http-server"), + tracer.SpanType(ext.SpanTypeWeb), + ) + + db := db.WithContext(ctx) + db.DryRun = true + var product Product + db.First(&product, "code = ?", "L1212") + + parentSpan.Finish() + + spans := mt.FinishedSpans() + a.Len(spans, 1) // No additional span generated }) t.Run("update", func(t *testing.T) { + a := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request", tracer.ServiceName("fake-http-server"), tracer.SpanType(ext.SpanTypeWeb), ) - db = db.WithContext(ctx) + db := db.WithContext(ctx) var queryText string db.Callback().Update().After("testing").Register("query text", func(d *gorm.DB) { queryText = d.Statement.SQL.String() @@ -268,15 +311,26 @@ func TestCallbacks(t *testing.T) { a.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType)) a.Equal(queryText, span.Tag(ext.ResourceName)) a.Equal("gorm.io/gorm.v1", span.Tag(ext.Component)) + a.Equal(parentSpan.Context().SpanID(), span.ParentID()) + + for _, s := range spans { + if s.Tag(ext.Component) == "jackc/pgx.v5" { + // The underlying driver should receive the gorm span + a.Equal(span.SpanID(), s.ParentID()) + } + } }) t.Run("delete", func(t *testing.T) { + a := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request", tracer.ServiceName("fake-http-server"), tracer.SpanType(ext.SpanTypeWeb), ) - db = db.WithContext(ctx) + db := db.WithContext(ctx) var queryText string db.Callback().Delete().After("testing").Register("query text", func(d *gorm.DB) { queryText = d.Statement.SQL.String() @@ -295,15 +349,26 @@ func TestCallbacks(t *testing.T) { a.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType)) a.Equal(queryText, span.Tag(ext.ResourceName)) a.Equal("gorm.io/gorm.v1", span.Tag(ext.Component)) + a.Equal(parentSpan.Context().SpanID(), span.ParentID()) + + for _, s := range spans { + if s.Tag(ext.Component) == "jackc/pgx.v5" { + // The underlying driver should receive the gorm span + a.Equal(span.SpanID(), s.ParentID()) + } + } }) t.Run("raw", func(t *testing.T) { + a := assert.New(t) + mt := mocktracer.Start() + defer mt.Stop() parentSpan, ctx := tracer.StartSpanFromContext(context.Background(), "http.request", tracer.ServiceName("fake-http-server"), tracer.SpanType(ext.SpanTypeWeb), ) - db = db.WithContext(ctx) + db := db.WithContext(ctx) var queryText string db.Callback().Raw().After("testing").Register("query text", func(d *gorm.DB) { queryText = d.Statement.SQL.String() @@ -321,6 +386,13 @@ func TestCallbacks(t *testing.T) { a.Equal("gorm.raw_query", span.OperationName()) a.Equal(ext.SpanTypeSQL, span.Tag(ext.SpanType)) a.Equal(queryText, span.Tag(ext.ResourceName)) + + for _, s := range spans { + if s.Tag(ext.Component) == "jackc/pgx.v5" { + // The underlying driver should receive the gorm span + a.Equal(span.SpanID(), s.ParentID()) + } + } }) } @@ -487,7 +559,7 @@ func TestCustomTags(t *testing.T) { db, err := Open( postgres.New(postgres.Config{Conn: sqlDb}), &gorm.Config{}, - WithCustomTag("foo", func(db *gorm.DB) interface{} { + WithCustomTag("foo", func(_ *gorm.DB) interface{} { return "bar" }), ) @@ -511,3 +583,35 @@ func TestCustomTags(t *testing.T) { assert.Equal("bar", s.Tag("foo")) } + +func TestPlugin(t *testing.T) { + db, err := gorm.Open(&tests.DummyDialector{}) + require.NoError(t, err) + + opt := WithCustomTag("foo", func(_ *gorm.DB) interface{} { + return "bar" + }) + plugin := NewTracePlugin(opt).(tracePlugin) + + assert.Equal(t, "DDTracePlugin", plugin.Name()) + assert.Len(t, plugin.options, 1) + require.NoError(t, db.Use(plugin)) + + assert.NotNil(t, db.Callback().Create().Get("dd-trace-go:before_create")) + assert.NotNil(t, db.Callback().Create().Get("dd-trace-go:after_create")) + + assert.NotNil(t, db.Callback().Update().Get("dd-trace-go:before_update")) + assert.NotNil(t, db.Callback().Update().Get("dd-trace-go:after_update")) + + assert.NotNil(t, db.Callback().Delete().Get("dd-trace-go:before_delete")) + assert.NotNil(t, db.Callback().Delete().Get("dd-trace-go:after_delete")) + + assert.NotNil(t, db.Callback().Query().Get("dd-trace-go:before_query")) + assert.NotNil(t, db.Callback().Query().Get("dd-trace-go:after_query")) + + assert.NotNil(t, db.Callback().Row().Get("dd-trace-go:before_row_query")) + assert.NotNil(t, db.Callback().Row().Get("dd-trace-go:after_row_query")) + + assert.NotNil(t, db.Callback().Raw().Get("dd-trace-go:before_raw_query")) + assert.NotNil(t, db.Callback().Raw().Get("dd-trace-go:before_raw_query")) +}