Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Scan and Value to pgtype.FlatArray, pgtype.Array, pgtype.Range, and pgtype.Multirange #2020

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions pgtype/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgtype

import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
Expand Down Expand Up @@ -419,6 +420,53 @@ func (a Array[T]) ScanIndexType() any {
return new(T)
}

// Scan implements the database/sql Scanner interface.
//
// Array needs a *Map to decode the array elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Scanner interface does not have an argument that can be used to pass extra data to the Scan
// method. To work around this limitation, Scan uses a default *Map. This means Scan may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (a *Array[T]) Scan(v any) error {
if v == nil {
*a = Array[T]{}
return nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

switch v := v.(type) {
case string:
return m.Scan(0, 0, []byte(v), a)
case []byte:
return m.Scan(0, 0, v, a)
default:
return fmt.Errorf("unsupported type: %T", v)
}
}

// Value implements the database/sql/driver Valuer interface.
//
// Array needs a *Map to encode the array elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Valuer interface does not have an argument that can be used to pass extra data to the Value
// method. To work around this limitation, Value uses a default *Map. This means Value may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (src Array[T]) Value() (driver.Value, error) {
if !src.Valid {
return nil, nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

buf, err := m.Encode(0, TextFormatCode, src, nil)
if err != nil {
return nil, err
}

return string(buf), nil
}

// FlatArray implements the ArrayGetter and ArraySetter interfaces for any slice of T. It ignores PostgreSQL dimensions
// and custom lower bounds. Use Array to preserve these.
type FlatArray[T any] []T
Expand Down Expand Up @@ -458,3 +506,78 @@ func (a FlatArray[T]) ScanIndex(i int) any {
func (a FlatArray[T]) ScanIndexType() any {
return new(T)
}

// Scan implements the database/sql Scanner interface.
//
// FlatArray needs a *Map to decode the array elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Scanner interface does not have an argument that can be used to pass extra data to the Scan
// method. To work around this limitation, Scan uses a default *Map. This means Scan may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (a *FlatArray[T]) Scan(v any) error {
if v == nil {
*a = nil
return nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

switch v := v.(type) {
case string:
return m.Scan(0, 0, []byte(v), a)
case []byte:
return m.Scan(0, 0, v, a)
default:
return fmt.Errorf("unsupported type: %T", v)
}
}

// Value implements the database/sql/driver Valuer interface.
//
// FlatArray needs a *Map to encode the array elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Valuer interface does not have an argument that can be used to pass extra data to the Value
// method. To work around this limitation, Value uses a default *Map. This means Value may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (src FlatArray[T]) Value() (driver.Value, error) {
if src == nil {
return nil, nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

buf, err := m.Encode(0, TextFormatCode, src, nil)
if err != nil {
return nil, err
}

return string(buf), nil
}

// flatArrayForWrapper is FlatArray without the Scan and Value methods. The avoids a wrapped plan attempt always
// "succeeding".
type flatArrayForWrapper[T any] []T

func (a flatArrayForWrapper[T]) Dimensions() []ArrayDimension {
return FlatArray[T](a).Dimensions()
}

func (a flatArrayForWrapper[T]) Index(i int) any {
return FlatArray[T](a).Index(i)
}

func (a flatArrayForWrapper[T]) IndexType() any {
return FlatArray[T](a).IndexType()
}

func (a *flatArrayForWrapper[T]) SetDimensions(dimensions []ArrayDimension) error {
return (*FlatArray[T])(a).SetDimensions(dimensions)
}

func (a flatArrayForWrapper[T]) ScanIndex(i int) any {
return FlatArray[T](a).ScanIndex(i)
}

func (a flatArrayForWrapper[T]) ScanIndexType() any {
return FlatArray[T](a).ScanIndexType()
}
47 changes: 47 additions & 0 deletions pgtype/multirange.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,50 @@ func (r Multirange[T]) ScanIndex(i int) any {
func (r Multirange[T]) ScanIndexType() any {
return new(T)
}

// Scan implements the database/sql Scanner interface.
//
// Range needs a *Map to decode the range elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Scanner interface does not have an argument that can be used to pass extra data to the Scan
// method. To work around this limitation, Scan uses a default *Map. This means Scan may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (r *Multirange[T]) Scan(v any) error {
if v == nil {
*r = nil
return nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

switch v := v.(type) {
case string:
return m.Scan(0, 0, []byte(v), r)
case []byte:
return m.Scan(0, 0, v, r)
default:
return fmt.Errorf("unsupported type: %T", v)
}
}

// Value implements the database/sql/driver Valuer interface.
//
// Range needs a *Map to encode the range elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Valuer interface does not have an argument that can be used to pass extra data to the Value
// method. To work around this limitation, Value uses a default *Map. This means Value may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (src Multirange[T]) Value() (driver.Value, error) {
if src == nil {
return nil, nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

buf, err := m.Encode(0, TextFormatCode, src, nil)
if err != nil {
return nil, err
}

return string(buf), nil
}
41 changes: 25 additions & 16 deletions pgtype/pgtype.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net"
"net/netip"
"reflect"
"sync"
"time"
)

Expand Down Expand Up @@ -123,6 +124,14 @@ const (
Int8multirangeArrayOID = 6157
)

// databaseSQLMapPool is a sync.Pool that holds *Map instances used for implementing sql.Scanner and driver.Valuer on
// types that need a *Map to encode and decode such as FlatArray[T].
var databaseSQLMapPool = sync.Pool{
New: func() any {
return NewMap()
},
}

type InfinityModifier int8

const (
Expand Down Expand Up @@ -932,19 +941,19 @@ func TryWrapPtrSliceScanPlan(target any) (plan WrappedScanPlanNextSetter, nextVa
// Avoid using reflect path for common types.
switch target := target.(type) {
case *[]int16:
return &wrapPtrSliceScanPlan[int16]{}, (*FlatArray[int16])(target), true
return &wrapPtrSliceScanPlan[int16]{}, (*flatArrayForWrapper[int16])(target), true
case *[]int32:
return &wrapPtrSliceScanPlan[int32]{}, (*FlatArray[int32])(target), true
return &wrapPtrSliceScanPlan[int32]{}, (*flatArrayForWrapper[int32])(target), true
case *[]int64:
return &wrapPtrSliceScanPlan[int64]{}, (*FlatArray[int64])(target), true
return &wrapPtrSliceScanPlan[int64]{}, (*flatArrayForWrapper[int64])(target), true
case *[]float32:
return &wrapPtrSliceScanPlan[float32]{}, (*FlatArray[float32])(target), true
return &wrapPtrSliceScanPlan[float32]{}, (*flatArrayForWrapper[float32])(target), true
case *[]float64:
return &wrapPtrSliceScanPlan[float64]{}, (*FlatArray[float64])(target), true
return &wrapPtrSliceScanPlan[float64]{}, (*flatArrayForWrapper[float64])(target), true
case *[]string:
return &wrapPtrSliceScanPlan[string]{}, (*FlatArray[string])(target), true
return &wrapPtrSliceScanPlan[string]{}, (*flatArrayForWrapper[string])(target), true
case *[]time.Time:
return &wrapPtrSliceScanPlan[time.Time]{}, (*FlatArray[time.Time])(target), true
return &wrapPtrSliceScanPlan[time.Time]{}, (*flatArrayForWrapper[time.Time])(target), true
}

targetType := reflect.TypeOf(target)
Expand All @@ -968,7 +977,7 @@ type wrapPtrSliceScanPlan[T any] struct {
func (plan *wrapPtrSliceScanPlan[T]) SetNext(next ScanPlan) { plan.next = next }

func (plan *wrapPtrSliceScanPlan[T]) Scan(src []byte, target any) error {
return plan.next.Scan(src, (*FlatArray[T])(target.(*[]T)))
return plan.next.Scan(src, (*flatArrayForWrapper[T])(target.(*[]T)))
}

type wrapPtrSliceReflectScanPlan struct {
Expand Down Expand Up @@ -1773,19 +1782,19 @@ func TryWrapSliceEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextVa
// Avoid using reflect path for common types.
switch value := value.(type) {
case []int16:
return &wrapSliceEncodePlan[int16]{}, (FlatArray[int16])(value), true
return &wrapSliceEncodePlan[int16]{}, (flatArrayForWrapper[int16])(value), true
case []int32:
return &wrapSliceEncodePlan[int32]{}, (FlatArray[int32])(value), true
return &wrapSliceEncodePlan[int32]{}, (flatArrayForWrapper[int32])(value), true
case []int64:
return &wrapSliceEncodePlan[int64]{}, (FlatArray[int64])(value), true
return &wrapSliceEncodePlan[int64]{}, (flatArrayForWrapper[int64])(value), true
case []float32:
return &wrapSliceEncodePlan[float32]{}, (FlatArray[float32])(value), true
return &wrapSliceEncodePlan[float32]{}, (flatArrayForWrapper[float32])(value), true
case []float64:
return &wrapSliceEncodePlan[float64]{}, (FlatArray[float64])(value), true
return &wrapSliceEncodePlan[float64]{}, (flatArrayForWrapper[float64])(value), true
case []string:
return &wrapSliceEncodePlan[string]{}, (FlatArray[string])(value), true
return &wrapSliceEncodePlan[string]{}, (flatArrayForWrapper[string])(value), true
case []time.Time:
return &wrapSliceEncodePlan[time.Time]{}, (FlatArray[time.Time])(value), true
return &wrapSliceEncodePlan[time.Time]{}, (flatArrayForWrapper[time.Time])(value), true
}

if valueType := reflect.TypeOf(value); valueType != nil && valueType.Kind() == reflect.Slice {
Expand All @@ -1805,7 +1814,7 @@ type wrapSliceEncodePlan[T any] struct {
func (plan *wrapSliceEncodePlan[T]) SetNext(next EncodePlan) { plan.next = next }

func (plan *wrapSliceEncodePlan[T]) Encode(value any, buf []byte) (newBuf []byte, err error) {
return plan.next.Encode((FlatArray[T])(value.([]T)), buf)
return plan.next.Encode((flatArrayForWrapper[T])(value.([]T)), buf)
}

type wrapSliceEncodeReflectPlan struct {
Expand Down
48 changes: 48 additions & 0 deletions pgtype/range.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgtype

import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
)
Expand Down Expand Up @@ -320,3 +321,50 @@ func (r *Range[T]) SetBoundTypes(lower, upper BoundType) error {
r.Valid = true
return nil
}

// Scan implements the database/sql Scanner interface.
//
// Range needs a *Map to decode the range elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Scanner interface does not have an argument that can be used to pass extra data to the Scan
// method. To work around this limitation, Scan uses a default *Map. This means Scan may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (a *Range[T]) Scan(v any) error {
if v == nil {
*a = Range[T]{}
return nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

switch v := v.(type) {
case string:
return m.Scan(0, 0, []byte(v), a)
case []byte:
return m.Scan(0, 0, v, a)
default:
return fmt.Errorf("unsupported type: %T", v)
}
}

// Value implements the database/sql/driver Valuer interface.
//
// Range needs a *Map to encode the range elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Valuer interface does not have an argument that can be used to pass extra data to the Value
// method. To work around this limitation, Value uses a default *Map. This means Value may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (src Range[T]) Value() (driver.Value, error) {
if !src.Valid {
return nil, nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

buf, err := m.Encode(0, TextFormatCode, src, nil)
if err != nil {
return nil, err
}

return string(buf), nil
}
Loading