diff --git a/float32.go b/float32.go index d6373ba..8096ffc 100644 --- a/float32.go +++ b/float32.go @@ -4,6 +4,9 @@ import ( "bytes" "database/sql/driver" "encoding/json" + "fmt" + "math" + "reflect" "strconv" "github.com/volatiletech/null/convert" @@ -44,14 +47,38 @@ func (f *Float32) UnmarshalJSON(data []byte) error { return nil } - var x float64 - if err := json.Unmarshal(data, &x); err != nil { + var err error + var v interface{} + if err = json.Unmarshal(data, &v); err != nil { return err } - f.Float32 = float32(x) - f.Valid = true - return nil + var r float64 + switch x := v.(type) { + case float64: + r = v.(float64) + case string: + str := string(x) + if len(str) == 0 { + f.Valid = false + return nil + } + + r, err = strconv.ParseFloat(str, 32) + case nil: + f.Valid = false + return nil + default: + err = fmt.Errorf("json: cannot unmarshal %v into Go value of type null.Float32", reflect.TypeOf(v).Name()) + } + + if r > math.MaxFloat32 { + return fmt.Errorf("json: %f overflows max float32 value", r) + } + + f.Float32 = float32(r) + f.Valid = (err == nil) && (f.Float32 != 0) + return err } // UnmarshalText implements encoding.TextUnmarshaler. diff --git a/float32_test.go b/float32_test.go index c29350a..7586d3b 100644 --- a/float32_test.go +++ b/float32_test.go @@ -6,7 +6,8 @@ import ( ) var ( - float32JSON = []byte(`1.2345`) + float32JSON = []byte(`1.2345`) + float32StringJSON = []byte(`"1.2345"`) ) func TestFloat32From(t *testing.T) { @@ -35,11 +36,21 @@ func TestUnmarshalFloat32(t *testing.T) { maybePanic(err) assertFloat32(t, f, "float32 json") + var sf Float32 + err = json.Unmarshal(float32StringJSON, &sf) + maybePanic(err) + assertFloat32(t, sf, "float32 string json") + var null Float32 err = json.Unmarshal(nullJSON, &null) maybePanic(err) assertNullFloat32(t, null, "null json") + var bf Float32 + err = json.Unmarshal(blankStringJSON, &bf) + maybePanic(err) + assertNullFloat32(t, bf, "blank json string") + var badType Float32 err = json.Unmarshal(boolJSON, &badType) if err == nil { diff --git a/float64.go b/float64.go index 49d1355..effc101 100644 --- a/float64.go +++ b/float64.go @@ -4,6 +4,8 @@ import ( "bytes" "database/sql/driver" "encoding/json" + "fmt" + "reflect" "strconv" "github.com/volatiletech/null/convert" @@ -43,13 +45,29 @@ func (f *Float64) UnmarshalJSON(data []byte) error { f.Valid = false return nil } - - if err := json.Unmarshal(data, &f.Float64); err != nil { + var err error + var v interface{} + if err = json.Unmarshal(data, &v); err != nil { return err } - - f.Valid = true - return nil + switch x := v.(type) { + case float64: + f.Float64 = float64(x) + case string: + str := string(x) + if len(str) == 0 { + f.Valid = false + return nil + } + f.Float64, err = strconv.ParseFloat(str, 64) + case nil: + f.Valid = false + return nil + default: + err = fmt.Errorf("json: cannot unmarshal %v into Go value of type null.Float64", reflect.TypeOf(v).Name()) + } + f.Valid = err == nil + return err } // UnmarshalText implements encoding.TextUnmarshaler. diff --git a/float64_test.go b/float64_test.go index 11d20d0..24d4fad 100644 --- a/float64_test.go +++ b/float64_test.go @@ -6,7 +6,8 @@ import ( ) var ( - float64JSON = []byte(`1.2345`) + float64JSON = []byte(`1.2345`) + float64StringJSON = []byte(`"1.2345"`) ) func TestFloat64From(t *testing.T) { @@ -35,11 +36,21 @@ func TestUnmarshalFloat64(t *testing.T) { maybePanic(err) assertFloat64(t, f, "float64 json") + var sf Float64 + err = json.Unmarshal(float64StringJSON, &sf) + maybePanic(err) + assertFloat64(t, sf, "float64 string json") + var null Float64 err = json.Unmarshal(nullJSON, &null) maybePanic(err) assertNullFloat64(t, null, "null json") + var bf Float64 + err = json.Unmarshal(blankStringJSON, &bf) + maybePanic(err) + assertNullFloat64(t, bf, "blank json string") + var badType Float64 err = json.Unmarshal(boolJSON, &badType) if err == nil { diff --git a/int.go b/int.go index 742c228..db8eb33 100644 --- a/int.go +++ b/int.go @@ -4,7 +4,9 @@ import ( "bytes" "database/sql/driver" "encoding/json" + "fmt" "math" + "reflect" "strconv" "github.com/volatiletech/null/convert" @@ -45,14 +47,31 @@ func (i *Int) UnmarshalJSON(data []byte) error { return nil } - var x int64 - if err := json.Unmarshal(data, &x); err != nil { + var err error + var v interface{} + if err = json.Unmarshal(data, &v); err != nil { return err } + switch x := v.(type) { + case float64: + // Unmarshal again, directly to int, to avoid intermediate float64 + err = json.Unmarshal(data, &i.Int) + case string: + str := string(x) + if len(str) == 0 { + i.Valid = false + return nil + } + i.Int, err = strconv.Atoi(str) + case nil: + i.Valid = false + return nil + default: + err = fmt.Errorf("json: cannot unmarshal %v into Go value of type null.Int", reflect.TypeOf(v).Name()) + } - i.Int = int(x) - i.Valid = true - return nil + i.Valid = (err == nil) && (i.Int != 0) + return err } // UnmarshalText implements encoding.TextUnmarshaler. diff --git a/int16.go b/int16.go index 445b20c..fb53e86 100644 --- a/int16.go +++ b/int16.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "math" + "reflect" "strconv" "github.com/volatiletech/null/convert" @@ -46,18 +47,39 @@ func (i *Int16) UnmarshalJSON(data []byte) error { return nil } - var x int64 - if err := json.Unmarshal(data, &x); err != nil { + var err error + var v interface{} + if err = json.Unmarshal(data, &v); err != nil { return err } - if x > math.MaxInt16 { - return fmt.Errorf("json: %d overflows max int16 value", x) + var r int64 + switch x := v.(type) { + case float64: + // Unmarshal again, directly to int64, to avoid intermediate float64 + err = json.Unmarshal(data, &r) + case string: + str := string(x) + if len(str) == 0 { + i.Valid = false + return nil + } + + r, err = strconv.ParseInt(str, 10, 16) + case nil: + i.Valid = false + return nil + default: + err = fmt.Errorf("json: cannot unmarshal %v into Go value of type null.Int16", reflect.TypeOf(v).Name()) } - i.Int16 = int16(x) - i.Valid = true - return nil + if r > math.MaxInt16 { + return fmt.Errorf("json: %d overflows max int16 value", r) + } + + i.Int16 = int16(r) + i.Valid = (err == nil) && (i.Int16 != 0) + return err } // UnmarshalText implements encoding.TextUnmarshaler. diff --git a/int16_test.go b/int16_test.go index bbe1e32..78e035f 100644 --- a/int16_test.go +++ b/int16_test.go @@ -8,7 +8,8 @@ import ( ) var ( - int16JSON = []byte(`32766`) + int16JSON = []byte(`32766`) + int16StringJSON = []byte(`"32766"`) ) func TestInt16From(t *testing.T) { @@ -37,11 +38,21 @@ func TestUnmarshalInt16(t *testing.T) { maybePanic(err) assertInt16(t, i, "int16 json") + var si Int16 + err = json.Unmarshal(int16StringJSON, &si) + maybePanic(err) + assertInt16(t, si, "int16 string json") + var null Int16 err = json.Unmarshal(nullJSON, &null) maybePanic(err) assertNullInt16(t, null, "null json") + var bi Int16 + err = json.Unmarshal(blankStringJSON, &bi) + maybePanic(err) + assertNullInt16(t, bi, "blank json string") + var badType Int16 err = json.Unmarshal(boolJSON, &badType) if err == nil { diff --git a/int32.go b/int32.go index eada3ef..9ef05e9 100644 --- a/int32.go +++ b/int32.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "math" + "reflect" "strconv" "github.com/volatiletech/null/convert" @@ -47,18 +48,39 @@ func (i *Int32) UnmarshalJSON(data []byte) error { return nil } - var x int64 - if err := json.Unmarshal(data, &x); err != nil { + var err error + var v interface{} + if err = json.Unmarshal(data, &v); err != nil { return err } - if x > math.MaxInt32 { - return fmt.Errorf("json: %d overflows max int32 value", x) + var r int64 + switch x := v.(type) { + case float64: + // Unmarshal again, directly to uint64, to avoid intermediate float64 + err = json.Unmarshal(data, &r) + case string: + str := string(x) + if len(str) == 0 { + i.Valid = false + return nil + } + + r, err = strconv.ParseInt(str, 10, 32) + case nil: + i.Valid = false + return nil + default: + err = fmt.Errorf("json: cannot unmarshal %v into Go value of type null.Int32", reflect.TypeOf(v).Name()) } - i.Int32 = int32(x) - i.Valid = true - return nil + if r > math.MaxInt32 { + return fmt.Errorf("json: %d overflows max int32 value", r) + } + + i.Int32 = int32(r) + i.Valid = (err == nil) && (i.Int32 != 0) + return err } // UnmarshalText implements encoding.TextUnmarshaler. diff --git a/int32_test.go b/int32_test.go index 683ca8d..77357f2 100644 --- a/int32_test.go +++ b/int32_test.go @@ -8,7 +8,8 @@ import ( ) var ( - int32JSON = []byte(`2147483646`) + int32JSON = []byte(`2147483646`) + int32StringJSON = []byte(`"2147483646"`) ) func TestInt32From(t *testing.T) { @@ -37,11 +38,21 @@ func TestUnmarshalInt32(t *testing.T) { maybePanic(err) assertInt32(t, i, "int32 json") + var si Int32 + err = json.Unmarshal(int32StringJSON, &si) + maybePanic(err) + assertInt32(t, si, "int32 string json") + var null Int32 err = json.Unmarshal(nullJSON, &null) maybePanic(err) assertNullInt32(t, null, "null json") + var bi Int32 + err = json.Unmarshal(blankStringJSON, &bi) + maybePanic(err) + assertNullInt32(t, bi, "blank json string") + var badType Int32 err = json.Unmarshal(boolJSON, &badType) if err == nil { diff --git a/int64.go b/int64.go index 5dfbee5..4048c2d 100644 --- a/int64.go +++ b/int64.go @@ -4,6 +4,8 @@ import ( "bytes" "database/sql/driver" "encoding/json" + "fmt" + "reflect" "strconv" "github.com/volatiletech/null/convert" @@ -44,12 +46,31 @@ func (i *Int64) UnmarshalJSON(data []byte) error { return nil } - if err := json.Unmarshal(data, &i.Int64); err != nil { + var err error + var v interface{} + if err = json.Unmarshal(data, &v); err != nil { return err } + switch x := v.(type) { + case float64: + // Unmarshal again, directly to int64, to avoid intermediate float64 + err = json.Unmarshal(data, &i.Int64) + case string: + str := string(x) + if len(str) == 0 { + i.Valid = false + return nil + } + i.Int64, err = strconv.ParseInt(str, 10, 64) + case nil: + i.Valid = false + return nil + default: + err = fmt.Errorf("json: cannot unmarshal %v into Go value of type null.Int64", reflect.TypeOf(v).Name()) + } - i.Valid = true - return nil + i.Valid = (err == nil) && (i.Int64 != 0) + return err } // UnmarshalText implements encoding.TextUnmarshaler. diff --git a/int64_test.go b/int64_test.go index 8fb11a7..495a314 100644 --- a/int64_test.go +++ b/int64_test.go @@ -8,7 +8,8 @@ import ( ) var ( - int64JSON = []byte(`9223372036854775806`) + int64JSON = []byte(`9223372036854775806`) + int64StringJSON = []byte(`"9223372036854775806"`) ) func TestInt64From(t *testing.T) { @@ -37,11 +38,21 @@ func TestUnmarshalInt64(t *testing.T) { maybePanic(err) assertInt64(t, i, "int64 json") + var si Int64 + err = json.Unmarshal(int64StringJSON, &si) + maybePanic(err) + assertInt64(t, si, "int64 string json") + var null Int64 err = json.Unmarshal(nullJSON, &null) maybePanic(err) assertNullInt64(t, null, "null json") + var bi Int64 + err = json.Unmarshal(blankStringJSON, &bi) + maybePanic(err) + assertNullInt64(t, bi, "blank json string") + var badType Int64 err = json.Unmarshal(boolJSON, &badType) if err == nil { diff --git a/int8.go b/int8.go index c6682bb..9651582 100644 --- a/int8.go +++ b/int8.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "math" + "reflect" "strconv" "github.com/volatiletech/null/convert" @@ -46,18 +47,39 @@ func (i *Int8) UnmarshalJSON(data []byte) error { return nil } - var x int64 - if err := json.Unmarshal(data, &x); err != nil { + var err error + var v interface{} + if err = json.Unmarshal(data, &v); err != nil { return err } - if x > math.MaxInt8 { - return fmt.Errorf("json: %d overflows max int8 value", x) + var r int64 + switch x := v.(type) { + case float64: + // Unmarshal again, directly to int64, to avoid intermediate float64 + err = json.Unmarshal(data, &r) + case string: + str := string(x) + if len(str) == 0 { + i.Valid = false + return nil + } + + r, err = strconv.ParseInt(str, 10, 8) + case nil: + i.Valid = false + return nil + default: + err = fmt.Errorf("json: cannot unmarshal %v into Go value of type null.Int8", reflect.TypeOf(v).Name()) } - i.Int8 = int8(x) - i.Valid = true - return nil + if r > math.MaxInt8 { + return fmt.Errorf("json: %d overflows max int8 value", r) + } + + i.Int8 = int8(r) + i.Valid = (err == nil) && (i.Int8 != 0) + return err } // UnmarshalText implements encoding.TextUnmarshaler. diff --git a/int8_test.go b/int8_test.go index 1c06a38..5503821 100644 --- a/int8_test.go +++ b/int8_test.go @@ -8,7 +8,8 @@ import ( ) var ( - int8JSON = []byte(`126`) + int8JSON = []byte(`126`) + int8StringJSON = []byte(`126`) ) func TestInt8From(t *testing.T) { @@ -37,11 +38,21 @@ func TestUnmarshalInt8(t *testing.T) { maybePanic(err) assertInt8(t, i, "int8 json") + var si Int8 + err = json.Unmarshal(int8StringJSON, &si) + maybePanic(err) + assertInt8(t, si, "int8 string json") + var null Int8 err = json.Unmarshal(nullJSON, &null) maybePanic(err) assertNullInt8(t, null, "null json") + var bi Int8 + err = json.Unmarshal(blankStringJSON, &bi) + maybePanic(err) + assertNullInt8(t, bi, "blank json string") + var badType Int8 err = json.Unmarshal(boolJSON, &badType) if err == nil { diff --git a/int_test.go b/int_test.go index d7aead9..7e0e71a 100644 --- a/int_test.go +++ b/int_test.go @@ -6,7 +6,8 @@ import ( ) var ( - intJSON = []byte(`12345`) + intJSON = []byte(`12345`) + intStringJSON = []byte(`"12345"`) ) func TestIntFrom(t *testing.T) { @@ -35,11 +36,21 @@ func TestUnmarshalInt(t *testing.T) { maybePanic(err) assertInt(t, i, "int json") + var si Int + err = json.Unmarshal(intStringJSON, &si) + maybePanic(err) + assertInt(t, si, "int string json") + var null Int err = json.Unmarshal(nullJSON, &null) maybePanic(err) assertNullInt(t, null, "null json") + var bi Int + err = json.Unmarshal(blankStringJSON, &bi) + maybePanic(err) + assertNullInt(t, bi, "blank json string") + var badType Int err = json.Unmarshal(boolJSON, &badType) if err == nil { diff --git a/uint.go b/uint.go index d750705..f97b446 100644 --- a/uint.go +++ b/uint.go @@ -4,6 +4,8 @@ import ( "bytes" "database/sql/driver" "encoding/json" + "fmt" + "reflect" "strconv" "github.com/volatiletech/null/convert" @@ -44,14 +46,35 @@ func (u *Uint) UnmarshalJSON(data []byte) error { return nil } - var x uint64 - if err := json.Unmarshal(data, &x); err != nil { + var err error + var v interface{} + if err = json.Unmarshal(data, &v); err != nil { return err } - u.Uint = uint(x) - u.Valid = true - return nil + var i uint64 + switch x := v.(type) { + case float64: + // Unmarshal again, directly to uint, to avoid intermediate float64 + err = json.Unmarshal(data, &i) + case string: + str := string(x) + if len(str) == 0 { + u.Valid = false + return nil + } + + i, err = strconv.ParseUint(str, 10, 64) + case nil: + u.Valid = false + return nil + default: + err = fmt.Errorf("json: cannot unmarshal %v into Go value of type null.Uint", reflect.TypeOf(v).Name()) + } + + u.Uint = uint(i) + u.Valid = (err == nil) && (u.Uint != 0) + return err } // UnmarshalText implements encoding.TextUnmarshaler. diff --git a/uint16.go b/uint16.go index 41ca4f5..b62d5e5 100644 --- a/uint16.go +++ b/uint16.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "math" + "reflect" "strconv" "github.com/volatiletech/null/convert" @@ -46,18 +47,39 @@ func (u *Uint16) UnmarshalJSON(data []byte) error { return nil } - var x uint64 - if err := json.Unmarshal(data, &x); err != nil { + var err error + var v interface{} + if err = json.Unmarshal(data, &v); err != nil { return err } - if x > math.MaxUint16 { - return fmt.Errorf("json: %d overflows max uint8 value", x) + var i uint64 + switch x := v.(type) { + case float64: + // Unmarshal again, directly to uint64, to avoid intermediate float64 + err = json.Unmarshal(data, &i) + case string: + str := string(x) + if len(str) == 0 { + u.Valid = false + return nil + } + + i, err = strconv.ParseUint(str, 10, 16) + case nil: + u.Valid = false + return nil + default: + err = fmt.Errorf("json: cannot unmarshal %v into Go value of type null.Uint16", reflect.TypeOf(v).Name()) } - u.Uint16 = uint16(x) - u.Valid = true - return nil + if i > math.MaxUint16 { + return fmt.Errorf("json: %d overflows max uint16 value", i) + } + + u.Uint16 = uint16(i) + u.Valid = (err == nil) && (u.Uint16 != 0) + return err } // UnmarshalText implements encoding.TextUnmarshaler. diff --git a/uint16_test.go b/uint16_test.go index e35d658..7bfb70e 100644 --- a/uint16_test.go +++ b/uint16_test.go @@ -8,7 +8,8 @@ import ( ) var ( - uint16JSON = []byte(`65534`) + uint16JSON = []byte(`65534`) + uint16StringJSON = []byte(`"65534"`) ) func TestUint16From(t *testing.T) { @@ -37,11 +38,21 @@ func TestUnmarshalUint16(t *testing.T) { maybePanic(err) assertUint16(t, i, "uint16 json") + var si Uint16 + err = json.Unmarshal(uint16StringJSON, &si) + maybePanic(err) + assertUint16(t, si, "uint16 string json") + var null Uint16 err = json.Unmarshal(nullJSON, &null) maybePanic(err) assertNullUint16(t, null, "null json") + var bi Uint16 + err = json.Unmarshal(blankStringJSON, &bi) + maybePanic(err) + assertNullUint16(t, bi, "blank json string") + var badType Uint16 err = json.Unmarshal(boolJSON, &badType) if err == nil { diff --git a/uint32.go b/uint32.go index 35cbbf1..d43f419 100644 --- a/uint32.go +++ b/uint32.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "math" + "reflect" "strconv" "github.com/volatiletech/null/convert" @@ -46,18 +47,39 @@ func (u *Uint32) UnmarshalJSON(data []byte) error { return nil } - var x uint64 - if err := json.Unmarshal(data, &x); err != nil { + var err error + var v interface{} + if err = json.Unmarshal(data, &v); err != nil { return err } - if x > math.MaxUint32 { - return fmt.Errorf("json: %d overflows max uint32 value", x) + var i uint64 + switch x := v.(type) { + case float64: + // Unmarshal again, directly to uint64, to avoid intermediate float64 + err = json.Unmarshal(data, &i) + case string: + str := string(x) + if len(str) == 0 { + u.Valid = false + return nil + } + + i, err = strconv.ParseUint(str, 10, 32) + case nil: + u.Valid = false + return nil + default: + err = fmt.Errorf("json: cannot unmarshal %v into Go value of type null.Uint32", reflect.TypeOf(v).Name()) } - u.Uint32 = uint32(x) - u.Valid = true - return nil + if i > math.MaxUint32 { + return fmt.Errorf("json: %d overflows max uint32 value", i) + } + + u.Uint32 = uint32(i) + u.Valid = (err == nil) && (u.Uint32 != 0) + return err } // UnmarshalText implements encoding.TextUnmarshaler. diff --git a/uint32_test.go b/uint32_test.go index 27d25a0..c590060 100644 --- a/uint32_test.go +++ b/uint32_test.go @@ -8,7 +8,8 @@ import ( ) var ( - uint32JSON = []byte(`4294967294`) + uint32JSON = []byte(`4294967294`) + uint32StringJSON = []byte(`"4294967294"`) ) func TestUint32From(t *testing.T) { @@ -37,11 +38,21 @@ func TestUnmarshalUint32(t *testing.T) { maybePanic(err) assertUint32(t, i, "uint32 json") + var si Uint32 + err = json.Unmarshal(uint32StringJSON, &si) + maybePanic(err) + assertUint32(t, si, "uint32 string json") + var null Uint32 err = json.Unmarshal(nullJSON, &null) maybePanic(err) assertNullUint32(t, null, "null json") + var bi Uint32 + err = json.Unmarshal(blankStringJSON, &bi) + maybePanic(err) + assertNullUint32(t, bi, "blank json string") + var badType Uint32 err = json.Unmarshal(boolJSON, &badType) if err == nil { diff --git a/uint64.go b/uint64.go index 61a29dd..857f692 100644 --- a/uint64.go +++ b/uint64.go @@ -4,6 +4,8 @@ import ( "bytes" "database/sql/driver" "encoding/json" + "fmt" + "reflect" "strconv" "github.com/volatiletech/null/convert" @@ -44,12 +46,31 @@ func (u *Uint64) UnmarshalJSON(data []byte) error { return nil } - if err := json.Unmarshal(data, &u.Uint64); err != nil { + var err error + var v interface{} + if err = json.Unmarshal(data, &v); err != nil { return err } + switch x := v.(type) { + case float64: + // Unmarshal again, directly to uint64, to avoid intermediate float64 + err = json.Unmarshal(data, &u.Uint64) + case string: + str := string(x) + if len(str) == 0 { + u.Valid = false + return nil + } + u.Uint64, err = strconv.ParseUint(str, 10, 64) + case nil: + u.Valid = false + return nil + default: + err = fmt.Errorf("json: cannot unmarshal %v into Go value of type null.Uint64", reflect.TypeOf(v).Name()) + } - u.Valid = true - return nil + u.Valid = (err == nil) && (u.Uint64 != 0) + return err } // UnmarshalText implements encoding.TextUnmarshaler. diff --git a/uint64_test.go b/uint64_test.go index c0abc14..590c364 100644 --- a/uint64_test.go +++ b/uint64_test.go @@ -6,7 +6,8 @@ import ( ) var ( - uint64JSON = []byte(`18446744073709551614`) + uint64JSON = []byte(`18446744073709551614`) + uint64StringJSON = []byte(`"18446744073709551614"`) ) func TestUint64From(t *testing.T) { @@ -35,11 +36,21 @@ func TestUnmarshalUint64(t *testing.T) { maybePanic(err) assertUint64(t, i, "uint64 json") + var si Uint64 + err = json.Unmarshal(uint64StringJSON, &si) + maybePanic(err) + assertUint64(t, si, "uint64 string json") + var null Uint64 err = json.Unmarshal(nullJSON, &null) maybePanic(err) assertNullUint64(t, null, "null json") + var bi Uint64 + err = json.Unmarshal(blankStringJSON, &bi) + maybePanic(err) + assertNullUint64(t, bi, "blank json string") + var badType Uint64 err = json.Unmarshal(boolJSON, &badType) if err == nil { diff --git a/uint8.go b/uint8.go index 1c6365c..389a127 100644 --- a/uint8.go +++ b/uint8.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "math" + "reflect" "strconv" "github.com/volatiletech/null/convert" @@ -46,18 +47,40 @@ func (u *Uint8) UnmarshalJSON(data []byte) error { return nil } - var x uint64 - if err := json.Unmarshal(data, &x); err != nil { + var err error + var v interface{} + if err = json.Unmarshal(data, &v); err != nil { return err } - if x > math.MaxUint8 { - return fmt.Errorf("json: %d overflows max uint8 value", x) + var i uint64 + switch x := v.(type) { + case float64: + // Unmarshal again, directly to uint64, to avoid intermediate float64 + err = json.Unmarshal(data, &i) + case string: + str := string(x) + if len(str) == 0 { + u.Valid = false + return nil + } + + i, err = strconv.ParseUint(str, 10, 8) + case nil: + u.Valid = false + return nil + default: + err = fmt.Errorf("json: cannot unmarshal %v into Go value of type null.Uint8", reflect.TypeOf(v).Name()) } - u.Uint8 = uint8(x) - u.Valid = true - return nil + if i > math.MaxUint8 { + return fmt.Errorf("json: %d overflows max uint8 value", i) + } + + u.Uint8 = uint8(i) + u.Valid = (err == nil) && (u.Uint8 != 0) + return err + } // UnmarshalText implements encoding.TextUnmarshaler. diff --git a/uint8_test.go b/uint8_test.go index cb0b80a..cf0e866 100644 --- a/uint8_test.go +++ b/uint8_test.go @@ -8,7 +8,8 @@ import ( ) var ( - uint8JSON = []byte(`254`) + uint8JSON = []byte(`254`) + uint8StringJSON = []byte(`"254"`) ) func TestUint8From(t *testing.T) { @@ -37,11 +38,21 @@ func TestUnmarshalUint8(t *testing.T) { maybePanic(err) assertUint8(t, i, "uint8 json") + var si Uint8 + err = json.Unmarshal(uint8StringJSON, &si) + maybePanic(err) + assertUint8(t, si, "uint8 string json") + var null Uint8 err = json.Unmarshal(nullJSON, &null) maybePanic(err) assertNullUint8(t, null, "null json") + var bi Uint8 + err = json.Unmarshal(blankStringJSON, &bi) + maybePanic(err) + assertNullUint8(t, bi, "blank json string") + var badType Uint8 err = json.Unmarshal(boolJSON, &badType) if err == nil { diff --git a/uint_test.go b/uint_test.go index 9ebd75f..0a7d281 100644 --- a/uint_test.go +++ b/uint_test.go @@ -6,7 +6,8 @@ import ( ) var ( - uintJSON = []byte(`12345`) + uintJSON = []byte(`12345`) + uintStringJSON = []byte(`"12345"`) ) func TestUintFrom(t *testing.T) { @@ -35,11 +36,21 @@ func TestUnmarshalUint(t *testing.T) { maybePanic(err) assertUint(t, i, "uint json") + var si Uint + err = json.Unmarshal(uintStringJSON, &si) + maybePanic(err) + assertUint(t, si, "uint string json") + var null Uint err = json.Unmarshal(nullJSON, &null) maybePanic(err) assertNullUint(t, null, "null json") + var bi Uint + err = json.Unmarshal(blankStringJSON, &bi) + maybePanic(err) + assertNullUint(t, bi, "blank json string") + var badType Uint err = json.Unmarshal(boolJSON, &badType) if err == nil {