Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ValueOrDefault method #15

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"database/sql/driver"
"encoding/json"
"errors"

"github.com/volatiletech/null/v9/convert"
)

Expand Down Expand Up @@ -148,3 +147,11 @@ func (b Bool) Value() (driver.Value, error) {
}
return b.Bool, nil
}

// ValueOrDefault returns the inner value if valid, otherwise default.
func (t Bool) ValueOrDefault() bool {
if !t.Valid {
return false
}
return t.Bool
}
12 changes: 12 additions & 0 deletions bool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,18 @@ func TestBoolScan(t *testing.T) {
assertNullBool(t, null, "scanned null")
}

func TestBoolValueOrDefault(t *testing.T) {
valid := NewBool(true, true)
if valid.ValueOrDefault() != true {
t.Error("unexpected ValueOrDefault", valid.ValueOrDefault())
}

invalid := NewBool(true, false)
if invalid.ValueOrDefault() != false {
t.Error("unexpected ValueOrDefault", invalid.ValueOrDefault())
}
}

func assertBool(t *testing.T, b Bool, from string) {
if b.Bool != true {
t.Errorf("bad %s bool: %v ≠ %v\n", from, b.Bool, true)
Expand Down
8 changes: 8 additions & 0 deletions byte.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,11 @@ func (b Byte) Value() (driver.Value, error) {
}
return []byte{b.Byte}, nil
}

// ValueOrDefault returns the inner value if valid, otherwise default.
func (t Byte) ValueOrDefault() byte {
if !t.Valid {
return 0
}
return t.Byte
}
12 changes: 12 additions & 0 deletions byte_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,18 @@ func TestByteScan(t *testing.T) {
assertNullByte(t, null, "scanned null")
}

func TestByteValueOrDefault(t *testing.T) {
valid := NewByte(1, true)
if valid.ValueOrDefault() != 1 {
t.Error("unexpected ValueOrDefault", valid.ValueOrDefault())
}

invalid := NewByte(1, false)
if invalid.ValueOrDefault() != 0 {
t.Error("unexpected ValueOrDefault", invalid.ValueOrDefault())
}
}

func assertByte(t *testing.T, i Byte, from string) {
if i.Byte != 'b' {
t.Errorf("bad %s int: %d ≠ %d\n", from, i.Byte, 'b')
Expand Down
9 changes: 8 additions & 1 deletion bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"database/sql/driver"
"encoding/json"

"github.com/volatiletech/null/v9/convert"
)

Expand Down Expand Up @@ -139,3 +138,11 @@ func (b Bytes) Value() (driver.Value, error) {
}
return b.Bytes, nil
}

// ValueOrDefault returns the inner value if valid, otherwise default.
func (t Bytes) ValueOrDefault() []byte {
if !t.Valid {
return []byte{}
}
return t.Bytes
}
13 changes: 13 additions & 0 deletions bytes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package null
import (
"bytes"
"encoding/json"
"reflect"
"testing"
)

Expand Down Expand Up @@ -153,6 +154,18 @@ func TestBytesScan(t *testing.T) {
assertNullBytes(t, null, "scanned null")
}

func TestBytesValueOrDefault(t *testing.T) {
valid := NewBytes([]byte{1}, true)
if !reflect.DeepEqual(valid.ValueOrDefault(), []byte{1}) {
t.Error("unexpected ValueOrDefault", valid.ValueOrDefault())
}

invalid := NewBytes([]byte{1}, false)
if reflect.DeepEqual(valid.ValueOrDefault(), []byte{}) {
t.Error("unexpected ValueOrDefault", invalid.ValueOrDefault())
}
}

func assertBytes(t *testing.T, i Bytes, from string) {
if !bytes.Equal(i.Bytes, hello) {
t.Errorf("bad %s []byte: %v ≠ %v\n", from, string(i.Bytes), "hello")
Expand Down
11 changes: 9 additions & 2 deletions float32.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ import (
"bytes"
"database/sql/driver"
"encoding/json"
"strconv"

"github.com/volatiletech/null/v9/convert"
"strconv"
)

// Float32 is a nullable float32.
Expand Down Expand Up @@ -137,3 +136,11 @@ func (f Float32) Value() (driver.Value, error) {
}
return float64(f.Float32), nil
}

// ValueOrDefault returns the inner value if valid, otherwise zero.
func (t Float32) ValueOrDefault() float32 {
if !t.Valid {
return 0.0
}
return t.Float32
}
12 changes: 12 additions & 0 deletions float32_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,18 @@ func TestFloat32Scan(t *testing.T) {
assertNullFloat32(t, null, "scanned null")
}

func TestFloat32ValueOrDefault(t *testing.T) {
valid := NewFloat32(1.0, true)
if valid.ValueOrDefault() != 1.0 {
t.Error("unexpected ValueOrDefault", valid.ValueOrDefault())
}

invalid := NewFloat32(1.0, false)
if invalid.ValueOrDefault() != 0.0 {
t.Error("unexpected ValueOrDefault", invalid.ValueOrDefault())
}
}

func assertFloat32(t *testing.T, f Float32, from string) {
if f.Float32 != 1.2345 {
t.Errorf("bad %s float32: %f ≠ %f\n", from, f.Float32, 1.2345)
Expand Down
8 changes: 8 additions & 0 deletions float64.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,11 @@ func (f Float64) Value() (driver.Value, error) {
}
return f.Float64, nil
}

// ValueOrDefault returns the inner value if valid, otherwise zero.
func (t Float64) ValueOrDefault() float64 {
if !t.Valid {
return 0.0
}
return t.Float64
}
12 changes: 12 additions & 0 deletions float64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,18 @@ func TestFloat64Scan(t *testing.T) {
assertNullFloat64(t, null, "scanned null")
}

func TestFloat64ValueOrDefault(t *testing.T) {
valid := NewFloat64(1.0, true)
if valid.ValueOrDefault() != 1.0 {
t.Error("unexpected ValueOrDefault", valid.ValueOrDefault())
}

invalid := NewFloat64(1.0, false)
if invalid.ValueOrDefault() != 0.0 {
t.Error("unexpected ValueOrDefault", invalid.ValueOrDefault())
}
}

func assertFloat64(t *testing.T, f Float64, from string) {
if f.Float64 != 1.2345 {
t.Errorf("bad %s float64: %f ≠ %f\n", from, f.Float64, 1.2345)
Expand Down
8 changes: 8 additions & 0 deletions int.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,11 @@ func (i Int) Value() (driver.Value, error) {
}
return int64(i.Int), nil
}

// ValueOrDefault returns the inner value if valid, otherwise default.
func (i Int) ValueOrDefault() int {
if !i.Valid {
return 0
}
return i.Int
}
8 changes: 8 additions & 0 deletions int16.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,11 @@ func (i Int16) Value() (driver.Value, error) {
}
return int64(i.Int16), nil
}

// ValueOrDefault returns the inner value if valid, otherwise default.
func (i Int16) ValueOrDefault() int16 {
if !i.Valid {
return 0
}
return i.Int16
}
12 changes: 12 additions & 0 deletions int16_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,18 @@ func TestInt16Scan(t *testing.T) {
assertNullInt16(t, null, "scanned null")
}

func TestInt16ValueOrDefault(t *testing.T) {
valid := NewInt16(1, true)
if valid.ValueOrDefault() != 1 {
t.Error("unexpected ValueOrDefault", valid.ValueOrDefault())
}

invalid := NewInt16(1, false)
if invalid.ValueOrDefault() != 0 {
t.Error("unexpected ValueOrDefault", invalid.ValueOrDefault())
}
}

func assertInt16(t *testing.T, i Int16, from string) {
if i.Int16 != 32766 {
t.Errorf("bad %s int16: %d ≠ %d\n", from, i.Int16, 32766)
Expand Down
8 changes: 8 additions & 0 deletions int32.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,11 @@ func (i Int32) Value() (driver.Value, error) {
}
return int64(i.Int32), nil
}

// ValueOrDefault returns the inner value if valid, otherwise default.
func (i Int32) ValueOrDefault() int32 {
if !i.Valid {
return 0
}
return i.Int32
}
12 changes: 12 additions & 0 deletions int32_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,18 @@ func TestInt32Scan(t *testing.T) {
assertNullInt32(t, null, "scanned null")
}

func TestInt32ValueOrDefault(t *testing.T) {
valid := NewInt32(1, true)
if valid.ValueOrDefault() != 1 {
t.Error("unexpected ValueOrDefault", valid.ValueOrDefault())
}

invalid := NewInt32(1, false)
if invalid.ValueOrDefault() != 0 {
t.Error("unexpected ValueOrDefault", invalid.ValueOrDefault())
}
}

func assertInt32(t *testing.T, i Int32, from string) {
if i.Int32 != 2147483646 {
t.Errorf("bad %s int32: %d ≠ %d\n", from, i.Int32, 2147483646)
Expand Down
8 changes: 8 additions & 0 deletions int64.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,11 @@ func (i Int64) Value() (driver.Value, error) {
}
return i.Int64, nil
}

// ValueOrDefault returns the inner value if valid, otherwise default.
func (i Int64) ValueOrDefault() int64 {
if !i.Valid {
return 0
}
return i.Int64
}
12 changes: 12 additions & 0 deletions int64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,18 @@ func TestInt64Scan(t *testing.T) {
assertNullInt64(t, null, "scanned null")
}

func TestInt64ValueOrDefault(t *testing.T) {
valid := NewInt64(1, true)
if valid.ValueOrDefault() != 1 {
t.Error("unexpected ValueOrDefault", valid.ValueOrDefault())
}

invalid := NewInt64(1, false)
if invalid.ValueOrDefault() != 0 {
t.Error("unexpected ValueOrDefault", invalid.ValueOrDefault())
}
}

func assertInt64(t *testing.T, i Int64, from string) {
if i.Int64 != 9223372036854775806 {
t.Errorf("bad %s int64: %d ≠ %d\n", from, i.Int64, 9223372036854775806)
Expand Down
8 changes: 8 additions & 0 deletions int8.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,11 @@ func (i Int8) Value() (driver.Value, error) {
}
return int64(i.Int8), nil
}

// ValueOrDefault returns the inner value if valid, otherwise default.
func (i Int8) ValueOrDefault() int8 {
if !i.Valid {
return 0
}
return i.Int8
}
12 changes: 12 additions & 0 deletions int8_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,18 @@ func TestInt8Scan(t *testing.T) {
assertNullInt8(t, null, "scanned null")
}

func TestInt8ValueOrDefault(t *testing.T) {
valid := NewInt8(1, true)
if valid.ValueOrDefault() != 1 {
t.Error("unexpected ValueOrDefault", valid.ValueOrDefault())
}

invalid := NewInt8(1, false)
if invalid.ValueOrDefault() != 0 {
t.Error("unexpected ValueOrDefault", invalid.ValueOrDefault())
}
}

func assertInt8(t *testing.T, i Int8, from string) {
if i.Int8 != 126 {
t.Errorf("bad %s int8: %d ≠ %d\n", from, i.Int8, 126)
Expand Down
12 changes: 12 additions & 0 deletions int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ func TestIntScan(t *testing.T) {
assertNullInt(t, null, "scanned null")
}

func TestIntValueOrDefault(t *testing.T) {
valid := NewInt(1, true)
if valid.ValueOrDefault() != 1 {
t.Error("unexpected ValueOrDefault", valid.ValueOrDefault())
}

invalid := NewInt(1, false)
if invalid.ValueOrDefault() != 0 {
t.Error("unexpected ValueOrDefault", invalid.ValueOrDefault())
}
}

func assertInt(t *testing.T, i Int, from string) {
if i.Int != 12345 {
t.Errorf("bad %s int: %d ≠ %d\n", from, i.Int, 12345)
Expand Down
15 changes: 11 additions & 4 deletions json.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"

"github.com/volatiletech/null/v9/convert"
)

Expand Down Expand Up @@ -80,9 +79,9 @@ func (j JSON) Unmarshal(dest interface{}) error {
//
// Example if you have a struct with a null.JSON called v:
//
// {} -> does not call unmarshaljson: !set & !valid
// {"v": null} -> calls unmarshaljson, set & !valid
// {"v": {}} -> calls unmarshaljson, set & valid (json value is '{}')
// {} -> does not call unmarshaljson: !set & !valid
// {"v": null} -> calls unmarshaljson, set & !valid
// {"v": {}} -> calls unmarshaljson, set & valid (json value is '{}')
//
// That's to say if 'null' is passed in at the json level we do not capture that
// value - instead we set the value-level null flag so that an sql value will
Expand Down Expand Up @@ -188,3 +187,11 @@ func (j JSON) Value() (driver.Value, error) {
}
return j.JSON, nil
}

// ValueOrDefault returns the inner value if valid, otherwise default.
func (t JSON) ValueOrDefault() []byte {
if !t.Valid {
return []byte{}
}
return t.JSON
}
Loading