diff --git a/conn.go b/conn.go index 311721459..b69e4b48c 100644 --- a/conn.go +++ b/conn.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "errors" "fmt" + "regexp" "strconv" "strings" "time" @@ -107,8 +108,10 @@ var ( ErrTooManyRows = errors.New("too many rows in result set") ) -var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") -var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") +var ( + errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") + errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") +) // Connect establishes a connection with a PostgreSQL server with a connection string. See // pgconn.Connect for details. @@ -843,7 +846,6 @@ func (c *Conn) getStatementDescription( mode QueryExecMode, sql string, ) (sd *pgconn.StatementDescription, err error) { - switch mode { case QueryExecModeCacheStatement: if c.statementCache == nil { @@ -1393,3 +1395,254 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error return nil } + +/* +buildLoadTypesSQL generates the correct query for retrieving type information. + + pgVersion: the major version of the PostgreSQL server + typeNames: the names of the types to load. If nil, load all types. +*/ +func buildLoadTypesSQL(pgVersion int64, typeNames []string) string { + supportsMultirange := (pgVersion >= 14) + var typeNamesClause string + + if typeNames == nil { + // collect all types. Not currently recommended. + typeNamesClause = "IS NOT NULL" + } else { + typeNamesClause = "= ANY($1)" + } + parts := make([]string, 0, 10) + + // Each of the type names provided might be found in pg_class or pg_type. + // Additionally, it may or may not include a schema portion. + parts = append(parts, ` +WITH RECURSIVE +-- find the OIDs in pg_class which match one of the provided type names +selected_classes(oid,reltype) AS ( + -- this query uses the namespace search path, so will match type names without a schema prefix + SELECT pg_class.oid, pg_class.reltype + FROM pg_catalog.pg_class + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace + WHERE pg_catalog.pg_table_is_visible(pg_class.oid) + AND relname `, typeNamesClause, ` +UNION ALL + -- this query will only match type names which include the schema prefix + SELECT pg_class.oid, pg_class.reltype + FROM pg_class + INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid) + WHERE nspname || '.' || relname `, typeNamesClause, ` +), +selected_types(oid) AS ( + -- collect the OIDs from pg_types which correspond to the selected classes + SELECT reltype AS oid + FROM selected_classes +UNION ALL + -- as well as any other type names which match our criteria + SELECT oid + FROM pg_type + WHERE typname `, typeNamesClause, ` +), +-- this builds a parent/child mapping of objects, allowing us to know +-- all the child (ie: dependent) types that a parent (type) requires +-- As can be seen, there are 3 ways this can occur (the last of which +-- is due to being a composite class, where the composite fields are children) +pc(parent, child) AS ( + SELECT parent.oid, parent.typelem + FROM pg_type parent + WHERE parent.typtype = 'b' AND parent.typelem != 0 +UNION ALL + SELECT parent.oid, parent.typbasetype + FROM pg_type parent + WHERE parent.typtypmod = -1 AND parent.typbasetype != 0 +UNION ALL + SELECT pg_type.oid, atttypid + FROM pg_attribute + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) + WHERE NOT attisdropped + AND attnum > 0 +), +-- Now construct a recursive query which includes a 'depth' element. +-- This is used to ensure that the "youngest" children are registered before +-- their parents. +relationships(parent, child, depth) AS ( + SELECT DISTINCT 0::OID, selected_types.oid, 0 + FROM selected_types +UNION ALL + SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1 + FROM selected_classes c + inner join pg_type ON (c.reltype = pg_type.oid) + inner join pg_attribute on (c.oid = pg_attribute.attrelid) +UNION ALL + SELECT pc.parent, pc.child, relationships.depth + 1 + FROM pc + INNER JOIN relationships ON (pc.parent = relationships.child) +), +-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration +composite AS ( + SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids + FROM pg_attribute + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) + WHERE NOT attisdropped + AND attnum > 0 + GROUP BY pg_type.oid +) +-- Bring together this information, showing all the information which might possibly be required +-- to complete the registration, applying filters to only show the items which relate to the selected +-- types/classes. +SELECT typname, + typtype, + typbasetype, + typelem, + pg_type.oid,`) + if supportsMultirange { + parts = append(parts, ` + COALESCE(multirange.rngtypid, 0) AS rngtypid,`) + } else { + parts = append(parts, ` + 0 AS rngtypid,`) + } + parts = append(parts, ` + COALESCE(pg_range.rngsubtype, 0) AS rngsubtype, + attnames, atttypids + FROM relationships + INNER JOIN pg_type ON (pg_type.oid IN ( relationships.child,relationships.parent) ) + LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`) + if supportsMultirange { + parts = append(parts, ` + LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`) + } + + parts = append(parts, ` + LEFT OUTER JOIN composite USING (oid) + WHERE NOT (typtype = 'b' AND typelem = 0)`) + parts = append(parts, ` + GROUP BY typname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`) + if supportsMultirange { + parts = append(parts, ` + multirange.rngtypid,`) + } + parts = append(parts, ` + attnames, atttypids + ORDER BY MAX(depth) desc, typname;`) + return strings.Join(parts, "") +} + +type TypeInfo struct { + oid, typbasetype, typelem, rngsubtype, rngtypid uint32 + typeName, typtype string + attnames []string + atttypids []uint32 +} + +// LoadTypes performs a single (complex) query, returning all the required +// information to register the named types, as well as any other types directly +// or indirectly required to complete the registration. +// The result of this call can be passed into RegisterTypes to complete the process. +func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*TypeInfo, error) { + if typeNames == nil || len(typeNames) == 0 { + return nil, fmt.Errorf("No type names were supplied.") + } + + serverVersion, err := c.serverVersion() + if err != nil { + return nil, fmt.Errorf("Unexpected server version error: %w", err) + } + sql := buildLoadTypesSQL(serverVersion, typeNames) + var rows Rows + if typeNames == nil { + rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol) + } else { + rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames) + } + if err != nil { + return nil, fmt.Errorf("While generating load types query: %w", err) + } + defer rows.Close() + result := make([]*TypeInfo, 0, 100) + for rows.Next() { + ti := TypeInfo{} + err = rows.Scan(&ti.typeName, &ti.typtype, &ti.typbasetype, &ti.typelem, &ti.oid, &ti.rngtypid, &ti.rngsubtype, &ti.attnames, &ti.atttypids) + if err != nil { + return nil, fmt.Errorf("While scanning type information: %w", err) + } + result = append(result, &ti) + } + return result, nil +} + +// RegisterTypes complements LoadTypes, applying the type information collected by LoadTypes +// to the connection's typemap. +func (c *Conn) RegisterTypes(typeInfo []*TypeInfo, registerWith *pgtype.Map) error { + if registerWith == nil { + return fmt.Errorf("Type map must be supplied") + } + for _, ti := range typeInfo { + switch ti.typtype { + case "b": // array + dt, ok := registerWith.TypeForOID(ti.typelem) + if !ok { + return fmt.Errorf("array element OID %v not registered while loading for %v", ti.typelem, ti.typeName) + } + registerWith.RegisterType(&pgtype.Type{Name: ti.typeName, OID: ti.oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}) + case "c": // composite + var fields []pgtype.CompositeCodecField + for i, fieldName := range ti.attnames { + //if fieldOID64, err = strconv.ParseUint(composite_fields[i+1], 10, 32); err != nil { + // return nil, fmt.Errorf("While extracting OID used in composite field: %w", err) + //} + dt, ok := registerWith.TypeForOID(ti.atttypids[i]) + if !ok { + return fmt.Errorf("unknown composite type field OID %v (%v)", ti.atttypids[i], fieldName) + } + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt}) + } + + registerWith.RegisterType(&pgtype.Type{Name: ti.typeName, OID: ti.oid, Codec: &pgtype.CompositeCodec{Fields: fields}}) + case "d": // domain + dt, ok := registerWith.TypeForOID(ti.typbasetype) + if !ok { + return fmt.Errorf("domain base type OID %v was not already registered, needed for %v", ti.typbasetype, ti.typeName) + } + + registerWith.RegisterType(&pgtype.Type{Name: ti.typeName, OID: ti.oid, Codec: dt.Codec}) + case "e": // enum + registerWith.RegisterType(&pgtype.Type{Name: ti.typeName, OID: ti.oid, Codec: &pgtype.EnumCodec{}}) + case "r": // range + dt, ok := registerWith.TypeForOID(ti.rngsubtype) + if !ok { + return fmt.Errorf("range element OID %v not registered for %v", ti.rngsubtype, ti.typeName) + } + + registerWith.RegisterType(&pgtype.Type{Name: ti.typeName, OID: ti.oid, Codec: &pgtype.RangeCodec{ElementType: dt}}) + case "m": // multirange + dt, ok := registerWith.TypeForOID(ti.rngtypid) + if !ok { + return fmt.Errorf("multirange element OID %v not registered while loading %v", ti.rngtypid, ti.typeName) + } + + registerWith.RegisterType(&pgtype.Type{Name: ti.typeName, OID: ti.oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}) + default: + return fmt.Errorf("unknown typtype %v for %v", ti.typtype, ti.typeName) + } + } + return nil +} + +// serverVersion returns the postgresql server version. +func (conn *Conn) serverVersion() (int64, error) { + serverVersionStr := conn.PgConn().ParameterStatus("server_version") + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) + // if not PostgreSQL do nothing + if serverVersionStr == "" { + return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr) + } + + serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64) + if err != nil { + return 0, fmt.Errorf("postgres version parsing failed: %w", err) + } + return serverVersion, nil +} diff --git a/go.sum b/go.sum index 4b02a0365..29fe452b2 100644 --- a/go.sum +++ b/go.sum @@ -4,10 +4,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index a049b448e..338982879 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -10,11 +10,58 @@ import ( "github.com/stretchr/testify/require" ) -func TestCompositeCodecTranscode(t *testing.T) { +func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, `drop domain if exists anotheruint64; +drop type if exists ct_test; +create domain anotheruint64 as numeric(20,0); + +create type ct_test as ( + a text, + b int4, + c anotheruint64 +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type ct_test") + defer conn.Exec(ctx, "drop domain anotheruint64") + + types, err := conn.LoadTypes(ctx, []string{"ct_test"}) + require.NoError(t, err) + err = conn.RegisterTypes(types, conn.TypeMap()) + require.NoError(t, err) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + for _, format := range formats { + var a string + var b int32 + var c uint64 + + err := conn.QueryRow(ctx, "select $1::ct_test", pgx.QueryResultFormats{format.code}, + pgtype.CompositeFields{"hi", int32(42), uint64(123)}, + ).Scan( + pgtype.CompositeFields{&a, &b, &c}, + ) + require.NoErrorf(t, err, "%v", format.name) + require.EqualValuesf(t, "hi", a, "%v", format.name) + require.EqualValuesf(t, 42, b, "%v", format.name) + require.EqualValuesf(t, 123, c, "%v", format.name) + } + }) +} + +func TestCompositeCodecTranscode(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { _, err := conn.Exec(ctx, `drop type if exists ct_test; create type ct_test as ( @@ -94,7 +141,6 @@ func TestCompositeCodecTranscodeStruct(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop type if exists point3d; create type point3d as ( @@ -131,7 +177,6 @@ func TestCompositeCodecTranscodeStructWrapper(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop type if exists point3d; create type point3d as ( @@ -172,7 +217,6 @@ func TestCompositeCodecDecodeValue(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop type if exists point3d; create type point3d as ( @@ -217,7 +261,6 @@ func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) { skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop table if exists point3d; create table point3d (