Skip to content

Commit

Permalink
db: promote postgresql.scannerValuer and postgresql.valueWrapper to s…
Browse files Browse the repository at this point in the history
…qlbuilder
  • Loading branch information
xiam committed Aug 19, 2017
1 parent d3a9083 commit 92c8aa9
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 65 deletions.
10 changes: 5 additions & 5 deletions lib/sqlbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func (b *sqlBuilder) PrepareContext(ctx context.Context, query interface{}) (*sq
case db.RawValue:
return b.PrepareContext(ctx, q.Raw())
default:
return nil, fmt.Errorf("Unsupported query type %T.", query)
return nil, fmt.Errorf("unsupported query type %T", query)
}
}

Expand All @@ -168,7 +168,7 @@ func (b *sqlBuilder) ExecContext(ctx context.Context, query interface{}, args ..
case db.RawValue:
return b.ExecContext(ctx, q.Raw(), q.Arguments()...)
default:
return nil, fmt.Errorf("Unsupported query type %T.", query)
return nil, fmt.Errorf("unsupported query type %T", query)
}
}

Expand All @@ -185,7 +185,7 @@ func (b *sqlBuilder) QueryContext(ctx context.Context, query interface{}, args .
case db.RawValue:
return b.QueryContext(ctx, q.Raw(), q.Arguments()...)
default:
return nil, fmt.Errorf("Unsupported query type %T.", query)
return nil, fmt.Errorf("unsupported query type %T", query)
}
}

Expand All @@ -202,7 +202,7 @@ func (b *sqlBuilder) QueryRowContext(ctx context.Context, query interface{}, arg
case db.RawValue:
return b.QueryRowContext(ctx, q.Raw(), q.Arguments()...)
default:
return nil, fmt.Errorf("Unsupported query type %T.", query)
return nil, fmt.Errorf("unsupported query type %T", query)
}
}

Expand Down Expand Up @@ -404,7 +404,7 @@ func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, err
case interface{}:
f[i] = exql.ColumnWithName(fmt.Sprintf("%v", v))
default:
return nil, nil, fmt.Errorf("Unexpected argument type %T for Select() argument.", v)
return nil, nil, fmt.Errorf("unexpected argument type %T for Select() argument", v)
}
}
return f, args, nil
Expand Down
10 changes: 5 additions & 5 deletions lib/sqlbuilder/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import (

// Common error messages.
var (
ErrExpectingPointer = errors.New(`Argument must be an address.`)
ErrExpectingSlicePointer = errors.New(`Argument must be a slice address.`)
ErrExpectingSliceMapStruct = errors.New(`Argument must be a slice address of maps or structs.`)
ErrExpectingMapOrStruct = errors.New(`Argument must be either a map or a struct.`)
ErrExpectingPointerToEitherMapOrStruct = errors.New(`Expecting a pointer to either a map or a struct.`)
ErrExpectingPointer = errors.New(`argument must be an address`)
ErrExpectingSlicePointer = errors.New(`argument must be a slice address`)
ErrExpectingSliceMapStruct = errors.New(`argument must be a slice address of maps or structs`)
ErrExpectingMapOrStruct = errors.New(`argument must be either a map or a struct`)
ErrExpectingPointerToEitherMapOrStruct = errors.New(`expecting a pointer to either a map or a struct`)
)
8 changes: 4 additions & 4 deletions lib/sqlbuilder/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,12 +310,12 @@ func (sel *selector) Using(columns ...interface{}) Selector {

joins := len(sq.joins)
if joins == 0 {
return errors.New(`Cannot use Using() without a preceding Join() expression.`)
return errors.New(`cannot use Using() without a preceding Join() expression`)
}

lastJoin := sq.joins[joins-1]
if lastJoin.On != nil {
return errors.New(`Cannot use Using() and On() with the same Join() expression.`)
return errors.New(`cannot use Using() and On() with the same Join() expression`)
}

fragments, args, err := columnFragments(columns)
Expand Down Expand Up @@ -365,12 +365,12 @@ func (sel *selector) On(terms ...interface{}) Selector {
joins := len(sq.joins)

if joins == 0 {
return errors.New(`Cannot use On() without a preceding Join() expression.`)
return errors.New(`cannot use On() without a preceding Join() expression`)
}

lastJoin := sq.joins[joins-1]
if lastJoin.On != nil {
return errors.New(`Cannot use Using() and On() with the same Join() expression.`)
return errors.New(`cannot use Using() and On() with the same Join() expression`)
}

w, a := sel.SQLBuilder().t.toWhereWithArguments(terms)
Expand Down
123 changes: 74 additions & 49 deletions postgresql/custom_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,36 +22,34 @@
package postgresql

import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"reflect"

"github.com/lib/pq"
"upper.io/db.v3/lib/sqlbuilder"
)

var (
driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
sqlScannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
valueWrapperType = reflect.TypeOf((*valueWrapper)(nil)).Elem()
)

func Array(in interface{}) scannerValuer {
// Array returns a sqlbuilder.ScannerValuer for any given slice. Slice elements
// may require their own sqlbuilder.ScannerValuer.
func Array(in interface{}) sqlbuilder.ScannerValuer {
return pq.Array(in)
}

// Type JSONB represents a PostgreSQL's JSONB value.
// JSONB represents a PostgreSQL's JSONB value:
// https://www.postgresql.org/docs/9.6/static/datatype-json.html. JSONB
// satisfies sqlbuilder.ScannerValuer.
type JSONB struct {
V interface{}
}

// MarshalJSON implements json.Marshaler
// MarshalJSON encodes the wrapper value as JSON.
func (j JSONB) MarshalJSON() ([]byte, error) {
return json.Marshal(j.V)
}

// UnmarshalJSON implements json.Unmarshaler
// UnmarshalJSON decodes the given JSON into the wrapped value.
func (j *JSONB) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
Expand All @@ -61,7 +59,7 @@ func (j *JSONB) UnmarshalJSON(b []byte) error {
return nil
}

// Scan implements the sql.Scanner interface.
// Scan satisfies the sql.Scanner interface.
func (j *JSONB) Scan(src interface{}) error {
if src == nil {
j.V = nil
Expand All @@ -79,7 +77,7 @@ func (j *JSONB) Scan(src interface{}) error {
return nil
}

// Value implements the driver.Valuer interface.
// Value satisfies the driver.Valuer interface.
func (j JSONB) Value() (driver.Value, error) {
// See https://github.com/lib/pq/issues/528#issuecomment-257197239 on why are
// we returning string instead of []byte.
Expand All @@ -96,15 +94,17 @@ func (j JSONB) Value() (driver.Value, error) {
return string(b), nil
}

// Type StringArray is an alias for pq.StringArray
// StringArray represents a one-dimensional array of strings (`[]string{}`)
// that is compatible with PostgreSQL's text array (`text[]`). StringArray
// satisfies sqlbuilder.ScannerValuer.
type StringArray pq.StringArray

// Value implements the driver.Valuer interface.
// Value satisfies the driver.Valuer interface.
func (a StringArray) Value() (driver.Value, error) {
return pq.StringArray(a).Value()
}

// Scan implements the sql.Scanner interface.
// Scan satisfies the sql.Scanner interface.
func (a *StringArray) Scan(src interface{}) error {
s := pq.StringArray(*a)
if err := s.Scan(src); err != nil {
Expand All @@ -114,15 +114,17 @@ func (a *StringArray) Scan(src interface{}) error {
return nil
}

// Type Int64Array is an alias for pq.Int64Array
// Int64Array represents a one-dimensional array of int64s (`[]int64{}`) that
// is compatible with PostgreSQL's integer array (`integer[]`). Int64Array
// satisfies sqlbuilder.ScannerValuer.
type Int64Array pq.Int64Array

// Value implements the driver.Valuer interface.
// Value satisfies the driver.Valuer interface.
func (i Int64Array) Value() (driver.Value, error) {
return pq.Int64Array(i).Value()
}

// Scan implements the sql.Scanner interface.
// Scan satisfies the sql.Scanner interface.
func (i *Int64Array) Scan(src interface{}) error {
s := pq.Int64Array(*i)
if err := s.Scan(src); err != nil {
Expand All @@ -132,15 +134,17 @@ func (i *Int64Array) Scan(src interface{}) error {
return nil
}

// Type Float64Array is an alias for pq.Float64Array
// Float64Array represents a one-dimensional array of float64s (`[]float64{}`)
// that is compatible with PostgreSQL's double precision array (`double
// precision[]`). Float64Array satisfies sqlbuilder.ScannerValuer.
type Float64Array pq.Float64Array

// Value implements the driver.Valuer interface.
// Value satisfies the driver.Valuer interface.
func (f Float64Array) Value() (driver.Value, error) {
return pq.Float64Array(f).Value()
}

// Scan implements the sql.Scanner interface.
// Scan satisfies the sql.Scanner interface.
func (f *Float64Array) Scan(src interface{}) error {
s := pq.Float64Array(*f)
if err := s.Scan(src); err != nil {
Expand All @@ -150,15 +154,17 @@ func (f *Float64Array) Scan(src interface{}) error {
return nil
}

// Type BoolArray is an alias for pq.BoolArray
// BoolArray represents a one-dimensional array of int64s (`[]bool{}`) that
// is compatible with PostgreSQL's boolean type (`boolean[]`). BoolArray
// satisfies sqlbuilder.ScannerValuer.
type BoolArray pq.BoolArray

// Value implements the driver.Valuer interface.
// Value satisfies the driver.Valuer interface.
func (b BoolArray) Value() (driver.Value, error) {
return pq.BoolArray(b).Value()
}

// Scan implements the sql.Scanner interface.
// Scan satisfies the sql.Scanner interface.
func (b *BoolArray) Scan(src interface{}) error {
s := pq.BoolArray(*b)
if err := s.Scan(src); err != nil {
Expand All @@ -168,15 +174,18 @@ func (b *BoolArray) Scan(src interface{}) error {
return nil
}

// Type GenericArray is an alias for pq.GenericArray
// GenericArray represents a one-dimensional array of any type
// (`[]interface{}`) that is compatible with PostgreSQL's array type.
// GenericArray satisfies sqlbuilder.ScannerValuer and its elements may need to
// satisfy sqlbuilder.ScannerValuer too.
type GenericArray pq.GenericArray

// Value implements the driver.Valuer interface.
// Value satisfies the driver.Valuer interface.
func (g GenericArray) Value() (driver.Value, error) {
return pq.GenericArray(g).Value()
}

// Scan implements the sql.Scanner interface.
// Scan satisfies the sql.Scanner interface.
func (g *GenericArray) Scan(src interface{}) error {
s := pq.GenericArray(*g)
if err := s.Scan(src); err != nil {
Expand All @@ -186,23 +195,33 @@ func (g *GenericArray) Scan(src interface{}) error {
return nil
}

// JSONBMap represents a map of interfaces with string keys
// (`map[string]interface{}`) that is compatible with PostgreSQL's JSONB type.
// JSONBMap satisfies sqlbuilder.ScannerValuer.
type JSONBMap map[string]interface{}

// Value satisfies the driver.Valuer interface.
func (m JSONBMap) Value() (driver.Value, error) {
return ToJSONBValue(m)
}

// Scan satisfies the sql.Scanner interface.
func (m *JSONBMap) Scan(src interface{}) error {
*m = map[string]interface{}(nil)
return FromJSONBValue(m, src)
}

// JSONBArray represents an array of any type (`[]interface{}`) that is
// compatible with PostgreSQL's JSONB type. JSONBArray satisfies
// sqlbuilder.ScannerValuer.
type JSONBArray []interface{}

// Value satisfies the driver.Valuer interface.
func (a JSONBArray) Value() (driver.Value, error) {
return ToJSONBValue(a)
}

// Scan satisfies the sql.Scanner interface.
func (a *JSONBArray) Scan(src interface{}) error {
return FromJSONBValue(a, src)
}
Expand All @@ -220,38 +239,42 @@ func FromJSONBValue(dst interface{}, src interface{}) error {
return v.Scan(src)
}

type valueWrapper interface {
WrapValue(interface{}) interface{}
}

// JSONBConverter provides a helper method WrapValue that satisfies
// sqlbuilder.ValueWrapper, can be used to encode Go structs into JSONB
// PostgreSQL types and vice versa.
//
// Example:
//
// type MyCustomStruct struct {
// ID int64 `db:"id" json:"id"`
// Name string `db:"name" json:"name"`
// ...
// postgresql.JSONBConverter
// }
type JSONBConverter struct {
}

// WrapValue satisfies sqlbuilder.ValueWrapper
func (obj *JSONBConverter) WrapValue(src interface{}) interface{} {
return &JSONB{src}
}

type scannerValuer interface {
driver.Valuer
sql.Scanner
}

func autoWrap(elem reflect.Value, v interface{}) interface{} {
kind := elem.Kind()

if kind == reflect.Invalid {
return v
}

if elem.Type().Implements(sqlScannerType) {
if elem.Type().Implements(sqlbuilder.ScannerType) {
return v
}

if elem.Type().Implements(driverValuerType) {
if elem.Type().Implements(sqlbuilder.ValuerType) {
return v
}

if elem.Type().Implements(valueWrapperType) {
if elem.Type().Implements(sqlbuilder.ValueWrapperType) {
if elem.Type().Kind() == reflect.Ptr {
w := reflect.ValueOf(v)
if w.Kind() == reflect.Ptr {
Expand All @@ -260,7 +283,7 @@ func autoWrap(elem reflect.Value, v interface{}) interface{} {
return &JSONB{v}
}
}
vw := elem.Interface().(valueWrapper)
vw := elem.Interface().(sqlbuilder.ValueWrapper)
return vw.WrapValue(elem.Interface())
}

Expand All @@ -281,13 +304,15 @@ func autoWrap(elem reflect.Value, v interface{}) interface{} {
return v
}

// Type checks.
var (
_ valueWrapper = &JSONBConverter{}
_ scannerValuer = &StringArray{}
_ scannerValuer = &Int64Array{}
_ scannerValuer = &Float64Array{}
_ scannerValuer = &BoolArray{}
_ scannerValuer = &GenericArray{}
_ scannerValuer = &JSONBMap{}
_ scannerValuer = &JSONBArray{}
_ sqlbuilder.ValueWrapper = &JSONBConverter{}

_ sqlbuilder.ScannerValuer = &StringArray{}
_ sqlbuilder.ScannerValuer = &Int64Array{}
_ sqlbuilder.ScannerValuer = &Float64Array{}
_ sqlbuilder.ScannerValuer = &BoolArray{}
_ sqlbuilder.ScannerValuer = &GenericArray{}
_ sqlbuilder.ScannerValuer = &JSONBMap{}
_ sqlbuilder.ScannerValuer = &JSONBArray{}
)
2 changes: 1 addition & 1 deletion postgresql/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (d *database) ConvertValues(values []interface{}) []interface{} {
case map[string]interface{}:
values[i] = (*JSONBMap)(&v)

case valueWrapper:
case sqlbuilder.ValueWrapper:
values[i] = v.WrapValue(v)

default:
Expand Down
Loading

0 comments on commit 92c8aa9

Please sign in to comment.