diff --git a/decimal.go b/decimal.go index 5a404c8..609c0ac 100644 --- a/decimal.go +++ b/decimal.go @@ -1847,6 +1847,14 @@ func (d *Decimal) Scan(value interface{}) error { *d = New(v, 0) return nil + case uint64: + // while clickhouse may send 0 in db as uint64 + *d = NewFromUint64(v) + return nil + + case nil: + return nil + default: // default is trying to interpret value stored as string str, err := unquoteIfQuoted(v) diff --git a/decimal_test.go b/decimal_test.go index 60993e0..d398f2d 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -2416,104 +2416,57 @@ func TestDecimal_Max(t *testing.T) { } } -func TestDecimal_Scan(t *testing.T) { - // test the Scan method that implements the - // sql.Scanner interface - // check for the for different type of values - // that are possible to be received from the database - // drivers +func scanHelper(t *testing.T, dbval interface{}, expected Decimal) { + t.Helper() - // in normal operations the db driver (sqlite at least) - // will return an int64 if you specified a numeric format a := Decimal{} - dbvalue := 54.33 - expected := NewFromFloat(dbvalue) - - err := a.Scan(dbvalue) - if err != nil { + if err := a.Scan(dbval); err != nil { // Scan failed... no need to test result value - t.Errorf("a.Scan(54.33) failed with message: %s", err) - - } else { + t.Errorf("a.Scan(%v) failed with message: %s", dbval, err) + } else if !a.Equal(expected) { // Scan succeeded... test resulting values - if !a.Equal(expected) { - t.Errorf("%s does not equal to %s", a, expected) - } + t.Errorf("%s does not equal to %s", a, expected) } +} + +func TestDecimal_Scan(t *testing.T) { + // test the Scan method that implements the sql.Scanner interface + // check different types received from various database drivers + + dbvalue := 54.33 + expected := NewFromFloat(dbvalue) + scanHelper(t, dbvalue, expected) // apparently MySQL 5.7.16 and returns these as float32 so we need // to handle these as well dbvalueFloat32 := float32(54.33) expected = NewFromFloat(float64(dbvalueFloat32)) - - err = a.Scan(dbvalueFloat32) - if err != nil { - // Scan failed... no need to test result value - t.Errorf("a.Scan(54.33) failed with message: %s", err) - - } else { - // Scan succeeded... test resulting values - if !a.Equal(expected) { - t.Errorf("%s does not equal to %s", a, expected) - } - } + scanHelper(t, dbvalueFloat32, expected) // at least SQLite returns an int64 when 0 is stored in the db // and you specified a numeric format on the schema dbvalueInt := int64(0) expected = New(dbvalueInt, 0) + scanHelper(t, dbvalueInt, expected) - err = a.Scan(dbvalueInt) - if err != nil { - // Scan failed... no need to test result value - t.Errorf("a.Scan(0) failed with message: %s", err) - - } else { - // Scan succeeded... test resulting values - if !a.Equal(expected) { - t.Errorf("%s does not equal to %s", a, expected) - } - } + // also test uint64 + dbvalueUint64 := uint64(2) + expected = New(2, 0) + scanHelper(t, dbvalueUint64, expected) // in case you specified a varchar in your SQL schema, - // the database driver will return byte slice []byte + // the database driver may return either []byte or string valueStr := "535.666" dbvalueStr := []byte(valueStr) - expected, err = NewFromString(valueStr) - if err != nil { - t.Fatal(err) - } - - err = a.Scan(dbvalueStr) - if err != nil { - // Scan failed... no need to test result value - t.Errorf("a.Scan('535.666') failed with message: %s", err) - - } else { - // Scan succeeded... test resulting values - if !a.Equal(expected) { - t.Errorf("%s does not equal to %s", a, expected) - } - } - - // lib/pq can also return strings - expected, err = NewFromString(valueStr) + expected, err := NewFromString(valueStr) if err != nil { t.Fatal(err) } - - err = a.Scan(valueStr) - if err != nil { - // Scan failed... no need to test result value - t.Errorf("a.Scan('535.666') failed with message: %s", err) - } else { - // Scan succeeded... test resulting values - if !a.Equal(expected) { - t.Errorf("%s does not equal to %s", a, expected) - } - } + scanHelper(t, dbvalueStr, expected) + scanHelper(t, valueStr, expected) type foo struct{} + a := Decimal{} err = a.Scan(foo{}) if err == nil { t.Errorf("a.Scan(Foo{}) should have thrown an error but did not")