diff --git a/pgtype/json.go b/pgtype/json.go index f65fa492e..48b9f9771 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -143,10 +143,12 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP case BytesScanner: return scanPlanBinaryBytesToBytesScanner{} + } + // Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence. // // https://github.com/jackc/pgx/issues/1418 - case sql.Scanner: + if isSQLScanner(target) { return &scanPlanSQLScanner{formatCode: format} } @@ -155,6 +157,20 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP } } +// we need to check if the target is a pointer to a sql.Scanner (or any of the pointer ref tree implements a sql.Scanner). +// +// https://github.com/jackc/pgx/issues/2146 +func isSQLScanner(v any) bool { + val := reflect.ValueOf(v) + for val.Kind() == reflect.Ptr { + if _, ok := val.Interface().(sql.Scanner); ok { + return true + } + val = val.Elem() + } + return false +} + type scanPlanAnyToString struct{} func (scanPlanAnyToString) Scan(src []byte, dst any) error { diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 3a8ddae5c..18ca5a8e4 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -63,6 +63,8 @@ func TestJSONCodec(t *testing.T) { // Test driver.Valuer is used before json.Marshaler (https://github.com/jackc/pgx/issues/1805) {Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))}, + // Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146) + {Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))}, }) pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{ @@ -109,6 +111,31 @@ func (i Issue1805) MarshalJSON() ([]byte, error) { return nil, errors.New("MarshalJSON called") } +type Issue2146 int + +func (i *Issue2146) Scan(src any) error { + var source []byte + switch src.(type) { + case string: + source = []byte(src.(string)) + case []byte: + source = src.([]byte) + default: + return errors.New("unknown source type") + } + var newI int + if err := json.Unmarshal(source, &newI); err != nil { + return err + } + *i = Issue2146(newI + 1) + return nil +} + +func (i Issue2146) Value() (driver.Value, error) { + b, err := json.Marshal(int(i - 1)) + return string(b), err +} + // https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648 func TestJSONCodecUnmarshalSQLNull(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index e756dac03..f9d43edd7 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -396,7 +396,12 @@ type scanPlanSQLScanner struct { } func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error { - scanner := dst.(sql.Scanner) + scanner := getSQLScanner(dst) + + if scanner == nil { + return fmt.Errorf("cannot scan into %T", dst) + } + if src == nil { // This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the // text format path would be converted to empty string. @@ -408,6 +413,21 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error { } } +// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively +func getSQLScanner(target any) sql.Scanner { + val := reflect.ValueOf(target) + for val.Kind() == reflect.Ptr { + if _, ok := val.Interface().(sql.Scanner); ok { + if val.IsNil() { + val.Set(reflect.New(val.Type().Elem())) + } + return val.Interface().(sql.Scanner) + } + val = val.Elem() + } + return nil +} + type scanPlanString struct{} func (scanPlanString) Scan(src []byte, dst any) error { diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index e8613aee1..eb1369403 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "os" + "reflect" "regexp" "strconv" "testing" @@ -631,3 +632,10 @@ func isExpectedEq(a any) func(any) bool { return a == v } } + +func isPtrExpectedEq(a any) func(any) bool { + return func(v any) bool { + val := reflect.ValueOf(v) + return a == val.Elem().Interface() + } +}