From 7a355851437cc62426d853ba4ea9211887a82473 Mon Sep 17 00:00:00 2001 From: nolandseigler <57370691+nolandseigler@users.noreply.github.com> Date: Thu, 11 Jul 2024 22:39:29 -0400 Subject: [PATCH 1/4] example test case that demonstrates snake case collision in db tags caused by rows.go 'fieldPosByName' --- rows_test.go | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/rows_test.go b/rows_test.go index bb9d50152..5dbf952da 100644 --- a/rows_test.go +++ b/rows_test.go @@ -667,6 +667,41 @@ func TestRowToStructByName(t *testing.T) { }) } +func TestRowToStructByNameDbTags(t *testing.T) { + type person struct { + Last string `db:"last_name"` + First string `db:"first_name"` + Age int32 `db:"age"` + AccountID string `db:"account_id"` + AnotherAccountID string `db:"account__id"` + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, 'd5e49d3f' as account_id, '5e49d321' as account__id from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByName[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Smith", slice[i].Last) + assert.Equal(t, "John", slice[i].First) + assert.EqualValues(t, i, slice[i].Age) + assert.Equal(t, "d5e49d3f", slice[i].AccountID) + assert.Equal(t, "5e49d321", slice[i].AnotherAccountID) + } + + // check missing fields in a returned row + rows, _ = conn.Query(ctx, `select 'Smith' as last_name, n as age from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToStructByName[person]) + assert.ErrorContains(t, err, "cannot find field First in returned row") + + // check missing field in a destination struct + rows, _ = conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, 'd5e49d3f' as account_id, '5e49d321' as account__id, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByName[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + }) +} + func TestRowToStructByNameEmbeddedStruct(t *testing.T) { type Name struct { Last string `db:"last_name"` From 7fceb64deee5adddfcb40abacd271a74a596fca5 Mon Sep 17 00:00:00 2001 From: nolandseigler <57370691+nolandseigler@users.noreply.github.com> Date: Thu, 11 Jul 2024 23:28:21 -0400 Subject: [PATCH 2/4] in rows.go 'fieldPosByName' use boolean to replace '_' and only execution replacements when there are no db tags present --- rows.go | 15 ++++++++++----- rows_test.go | 4 ++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/rows.go b/rows.go index d4f7a9016..9dc66ac06 100644 --- a/rows.go +++ b/rows.go @@ -797,7 +797,7 @@ func computeNamedStructFields( if !dbTagPresent { colName = sf.Name } - fpos := fieldPosByName(fldDescs, colName) + fpos := fieldPosByName(fldDescs, colName, !dbTagPresent) if fpos == -1 { if missingField == "" { missingField = colName @@ -816,13 +816,18 @@ func computeNamedStructFields( const structTagKey = "db" -func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) { +func fieldPosByName(fldDescs []pgconn.FieldDescription, field string, replace bool) (i int) { i = -1 + if replace { + field = strings.ReplaceAll(field, "_", "") + } for i, desc := range fldDescs { // Snake case support. - field = strings.ReplaceAll(field, "_", "") - descName := strings.ReplaceAll(desc.Name, "_", "") + descName := desc.Name + if replace { + descName = strings.ReplaceAll(desc.Name, "_", "") + } if strings.EqualFold(descName, field) { return i @@ -848,4 +853,4 @@ func setupStructScanTargets(receiver any, fields []structRowField) []any { scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface() } return scanTargets -} +} \ No newline at end of file diff --git a/rows_test.go b/rows_test.go index 5dbf952da..5494725cb 100644 --- a/rows_test.go +++ b/rows_test.go @@ -693,7 +693,7 @@ func TestRowToStructByNameDbTags(t *testing.T) { // check missing fields in a returned row rows, _ = conn.Query(ctx, `select 'Smith' as last_name, n as age from generate_series(0, 9) n`) _, err = pgx.CollectRows(rows, pgx.RowToStructByName[person]) - assert.ErrorContains(t, err, "cannot find field First in returned row") + assert.ErrorContains(t, err, "cannot find field first_name in returned row") // check missing field in a destination struct rows, _ = conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, 'd5e49d3f' as account_id, '5e49d321' as account__id, null as ignore from generate_series(0, 9) n`) @@ -992,4 +992,4 @@ insert into products (name, price) values // Cheeseburger: $10 // Fries: $5 // Soft Drink: $3 -} +} \ No newline at end of file From b25d092d2043608c700b68d4164d58b93b02962f Mon Sep 17 00:00:00 2001 From: nolandseigler <57370691+nolandseigler@users.noreply.github.com> Date: Thu, 11 Jul 2024 23:30:28 -0400 Subject: [PATCH 3/4] formatting --- rows.go | 2 +- rows_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rows.go b/rows.go index 9dc66ac06..459e20968 100644 --- a/rows.go +++ b/rows.go @@ -853,4 +853,4 @@ func setupStructScanTargets(receiver any, fields []structRowField) []any { scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface() } return scanTargets -} \ No newline at end of file +} diff --git a/rows_test.go b/rows_test.go index 5494725cb..4cda957fc 100644 --- a/rows_test.go +++ b/rows_test.go @@ -992,4 +992,4 @@ insert into products (name, price) values // Cheeseburger: $10 // Fries: $5 // Soft Drink: $3 -} \ No newline at end of file +} From 71a8e53574e1e60f7f75d6873d1f98d0ccea31bf Mon Sep 17 00:00:00 2001 From: nolandseigler <57370691+nolandseigler@users.noreply.github.com> Date: Fri, 12 Jul 2024 08:50:54 -0400 Subject: [PATCH 4/4] use normalized equality or strict equality check in rows.go fieldPosByName --- rows.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/rows.go b/rows.go index 459e20968..f23625d4c 100644 --- a/rows.go +++ b/rows.go @@ -816,21 +816,21 @@ func computeNamedStructFields( const structTagKey = "db" -func fieldPosByName(fldDescs []pgconn.FieldDescription, field string, replace bool) (i int) { +func fieldPosByName(fldDescs []pgconn.FieldDescription, field string, normalize bool) (i int) { i = -1 - if replace { + + if normalize { field = strings.ReplaceAll(field, "_", "") } for i, desc := range fldDescs { - - // Snake case support. - descName := desc.Name - if replace { - descName = strings.ReplaceAll(desc.Name, "_", "") - } - - if strings.EqualFold(descName, field) { - return i + if normalize { + if strings.EqualFold(strings.ReplaceAll(desc.Name, "_", ""), field) { + return i + } + } else { + if desc.Name == field { + return i + } } } return