Skip to content

Commit

Permalink
Simplify registration of custom types with pgxpool
Browse files Browse the repository at this point in the history
Sometimes pgx is not capable of inferring the correct codec by
inspecting the database, as is done with LoadType(s), and requires
a user-defined operation to be performed instead.

To allow these custom types to also benefit from connection
initialisation, it is possible to register these functions with pgxpool
using the new RegisterCustomType method.
When using the default configuration, this will still require each new
connection to perform one query to the backend to retrieve the OIDs for
these custom types. This is already a benefit, instead of requiring
a query for each custom type, with the associated latency.

Even better, when the reuseTypeMap pgxpool configuration is selected,
only the first connection requires this query; subsequent connections
will execute the custom registration code, using the cached OID mapping.
  • Loading branch information
nicois committed Jun 23, 2024
1 parent 2538c40 commit 0097c59
Showing 1 changed file with 99 additions and 21 deletions.
120 changes: 99 additions & 21 deletions pgxpool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,13 @@ type Pool struct {
closeOnce sync.Once
closeChan chan struct{}

autoLoadTypeNames []string
reuseTypeMap bool
autoLoadMutex *sync.Mutex
autoLoadTypes []*pgtype.Type
autoLoadTypeNames []string
reuseTypeMap bool
autoLoadMutex *sync.Mutex
autoLoadTypes []*pgtype.Type
customRegistrationMap map[string]CustomRegistrationFunction
customRegistrationMutex *sync.Mutex
customRegistrationOidMap map[string]uint32
}

// Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be
Expand Down Expand Up @@ -198,6 +201,10 @@ func New(ctx context.Context, connString string) (*Pool, error) {
return NewWithConfig(ctx, config)
}

// CustomRegistrationFunction is capable of registering whatever is necessary for
// a custom type. It is provided with the backend's OID for this type.
type CustomRegistrationFunction func(ctx context.Context, m *pgtype.Map, oid uint32) error

// NewWithConfig creates a new Pool. config must have been created by [ParseConfig].
func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
Expand All @@ -207,23 +214,25 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
}

p := &Pool{
config: config,
beforeConnect: config.BeforeConnect,
afterConnect: config.AfterConnect,
autoLoadTypeNames: config.AutoLoadTypes,
reuseTypeMap: config.ReuseTypeMaps,
beforeAcquire: config.BeforeAcquire,
afterRelease: config.AfterRelease,
beforeClose: config.BeforeClose,
minConns: config.MinConns,
maxConns: config.MaxConns,
maxConnLifetime: config.MaxConnLifetime,
maxConnLifetimeJitter: config.MaxConnLifetimeJitter,
maxConnIdleTime: config.MaxConnIdleTime,
healthCheckPeriod: config.HealthCheckPeriod,
healthCheckChan: make(chan struct{}, 1),
closeChan: make(chan struct{}),
autoLoadMutex: new(sync.Mutex),
config: config,
beforeConnect: config.BeforeConnect,
afterConnect: config.AfterConnect,
autoLoadTypeNames: config.AutoLoadTypes,
reuseTypeMap: config.ReuseTypeMaps,
beforeAcquire: config.BeforeAcquire,
afterRelease: config.AfterRelease,
beforeClose: config.BeforeClose,
minConns: config.MinConns,
maxConns: config.MaxConns,
maxConnLifetime: config.MaxConnLifetime,
maxConnLifetimeJitter: config.MaxConnLifetimeJitter,
maxConnIdleTime: config.MaxConnIdleTime,
healthCheckPeriod: config.HealthCheckPeriod,
healthCheckChan: make(chan struct{}, 1),
closeChan: make(chan struct{}),
autoLoadMutex: new(sync.Mutex),
customRegistrationMap: make(map[string]CustomRegistrationFunction),
customRegistrationMutex: new(sync.Mutex),
}

if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok {
Expand Down Expand Up @@ -265,6 +274,24 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
}
}

if len(p.customRegistrationMap) > 0 {
oidMap, err := p.getOidMapForCustomRegistration(ctx, conn)
if err != nil {
conn.Close(ctx)
return nil, fmt.Errorf("While retrieving OIDs for custom type registration: %w", err)
}
for typeName, f := range p.customRegistrationMap {
if oid, exists := oidMap[typeName]; exists {
if err := f(ctx, conn.TypeMap(), oid); err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("Type %q does not have an associated OID.", typeName)
}
}

}

if p.autoLoadTypeNames != nil && len(p.autoLoadTypeNames) > 0 {
types, err := p.loadTypes(ctx, conn, p.autoLoadTypeNames)
if err != nil {
Expand Down Expand Up @@ -315,6 +342,51 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
return p, nil
}

func (p *Pool) getOidMapForCustomRegistration(ctx context.Context, conn *pgx.Conn) (map[string]uint32, error) {
if p.reuseTypeMap {
p.customRegistrationMutex.Lock()
defer p.customRegistrationMutex.Unlock()
if p.customRegistrationOidMap != nil {
return p.customRegistrationOidMap, nil
}
oidMap, err := p.fetchOidMapForCustomRegistration(ctx, conn)
if err != nil {
return nil, err
}
p.customRegistrationOidMap = oidMap
return oidMap, nil
}
// Avoid needing to acquire the mutex and allow connections to initialise in parallel
// if we have chosen to not reuse the type mapping
return p.fetchOidMapForCustomRegistration(ctx, conn)
}

func (p *Pool) fetchOidMapForCustomRegistration(ctx context.Context, conn *pgx.Conn) (map[string]uint32, error) {
sql := `
SELECT oid, typname
FROM pg_type
WHERE typname = ANY($1)`
result := make(map[string]uint32)
typeNames := make([]string, 0, len(p.customRegistrationMap))
for typeName := range p.customRegistrationMap {
typeNames = append(typeNames, typeName)
}
rows, err := conn.Query(ctx, sql, typeNames)
if err != nil {
return nil, fmt.Errorf("While collecting OIDs for custom registrations: %w", err)
}
defer rows.Close()
var typeName string
var oid uint32
for rows.Next() {
if err := rows.Scan(&typeName, &oid); err != nil {
return nil, fmt.Errorf("While scanning a row for custom registrations: %w", err)
}
result[typeName] = oid
}
return result, nil
}

// ParseConfig builds a Config from connString. It parses connString with the same behavior as [pgx.ParseConfig] with the
// addition of the following variables:
//
Expand Down Expand Up @@ -425,6 +497,12 @@ func (p *Pool) Close() {
})
}

// RegisterCustomType is used to provide a function capable of performing
// type registration for situations where the autoloader is unable to do so on its own
func (p *Pool) RegisterCustomType(typeName string, f CustomRegistrationFunction) {
p.customRegistrationMap[typeName] = f
}

// loadTypes is used internally to autoload the custom types for a connection,
// potentially reusing previously-loaded typemap information.
func (p *Pool) loadTypes(ctx context.Context, conn *pgx.Conn, typeNames []string) ([]*pgtype.Type, error) {
Expand Down

0 comments on commit 0097c59

Please sign in to comment.