Skip to content

Commit

Permalink
Fix decoding of scientific notation (#463)
Browse files Browse the repository at this point in the history
* Fix scientific notation decoding and add encoding test cases

* Deal with ints and uints

* Add coverage for uint changes
  • Loading branch information
morris-kelly authored Jul 16, 2024
1 parent 1f84c0c commit b5f63d5
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 1 deletion.
34 changes: 34 additions & 0 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,21 @@ func (d *Decoder) fileToNode(f *ast.File) ast.Node {
func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type, src ast.Node) (reflect.Value, error) {
if typ.Kind() != reflect.String {
if !v.Type().ConvertibleTo(typ) {

// Special case for "strings -> floats" aka scientific notation
// If the destination type is a float and the source type is a string, check if we can
// use strconv.ParseFloat to convert the string to a float.
if (typ.Kind() == reflect.Float32 || typ.Kind() == reflect.Float64) &&
v.Type().Kind() == reflect.String {
if f, err := strconv.ParseFloat(v.String(), 64); err == nil {
if typ.Kind() == reflect.Float32 {
return reflect.ValueOf(float32(f)), nil
} else if typ.Kind() == reflect.Float64 {
return reflect.ValueOf(f), nil
}
// else, fall through to the error below
}
}
return reflect.Zero(typ), errTypeMismatch(typ, v.Type(), src.GetToken())
}
return v.Convert(typ), nil
Expand Down Expand Up @@ -877,6 +892,15 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No
dst.SetInt(int64(vv))
return nil
}
case string: // handle scientific notation
if i, err := strconv.ParseFloat(vv, 64); err == nil {
if 0 <= i && i <= math.MaxUint64 && !dst.OverflowInt(int64(i)) {
dst.SetInt(int64(i))
return nil
}
} else { // couldn't be parsed as float
return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
}
default:
return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
}
Expand All @@ -899,6 +923,16 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No
dst.SetUint(uint64(vv))
return nil
}
case string: // handle scientific notation
if i, err := strconv.ParseFloat(vv, 64); err == nil {
if 0 <= i && i <= math.MaxUint64 && !dst.OverflowUint(uint64(i)) {
dst.SetUint(uint64(i))
return nil
}
} else { // couldn't be parsed as float
return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
}

default:
return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken())
}
Expand Down
96 changes: 95 additions & 1 deletion decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ func TestDecoder(t *testing.T) {
"v: 4294967295",
map[string]uint{"v": math.MaxUint32},
},
{
"v: 1e3",
map[string]uint{"v": 1000},
},

// uint64
{
Expand All @@ -271,6 +275,10 @@ func TestDecoder(t *testing.T) {
"v: 9223372036854775807",
map[string]uint64{"v": math.MaxInt64},
},
{
"v: 1e3",
map[string]uint64{"v": 1000},
},

// float32
{
Expand All @@ -289,6 +297,10 @@ func TestDecoder(t *testing.T) {
"v: 18446744073709551616",
map[string]float32{"v": float32(math.MaxUint64 + 1)},
},
{
"v: 1e-06",
map[string]float32{"v": 1e-6},
},

// float64
{
Expand All @@ -307,6 +319,10 @@ func TestDecoder(t *testing.T) {
"v: 18446744073709551616",
map[string]float64{"v": float64(math.MaxUint64 + 1)},
},
{
"v: 1e-06",
map[string]float64{"v": 1e-06},
},

// Timestamps
{
Expand Down Expand Up @@ -1093,6 +1109,73 @@ c:
}
}

func TestDecoder_ScientificNotation(t *testing.T) {
tests := []struct {
source string
value interface{}
}{
{
"v: 1e3",
map[string]uint{"v": 1000},
},
{
"v: 1e-3",
map[string]uint{"v": 0},
},
{
"v: 1e3",
map[string]int{"v": 1000},
},
{
"v: 1e-3",
map[string]int{"v": 0},
},
{
"v: 1e3",
map[string]float32{"v": 1000},
},
{
"v: 1.0e3",
map[string]float64{"v": 1000},
},
{
"v: 1e-3",
map[string]float64{"v": 0.001},
},
{
"v: 1.0e-3",
map[string]float64{"v": 0.001},
},
{
"v: 1.0e+3",
map[string]float64{"v": 1000},
},
{
"v: 1.0e+3",
map[string]float64{"v": 1000},
},
}
for _, test := range tests {
t.Run(test.source, func(t *testing.T) {
buf := bytes.NewBufferString(test.source)
dec := yaml.NewDecoder(buf)
typ := reflect.ValueOf(test.value).Type()
value := reflect.New(typ)
if err := dec.Decode(value.Interface()); err != nil {
if err == io.EOF {
return
}
t.Fatalf("%s: %+v", test.source, err)
}
actual := fmt.Sprintf("%+v", value.Elem().Interface())
expect := fmt.Sprintf("%+v", test.value)
if actual != expect {
t.Fatalf("failed to test [%s], actual=[%s], expect=[%s]", test.source, actual, expect)
}
})
}
}

func TestDecoder_TypeConversionError(t *testing.T) {
t.Run("type conversion for struct", func(t *testing.T) {
type T struct {
Expand All @@ -1115,6 +1198,17 @@ func TestDecoder_TypeConversionError(t *testing.T) {
t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg)
}
})
t.Run("string to uint", func(t *testing.T) {
var v T
err := yaml.Unmarshal([]byte(`b: str`), &v)
if err == nil {
t.Fatal("expected to error")
}
msg := "cannot unmarshal string into Go struct field T.B of type uint"
if !strings.Contains(err.Error(), msg) {
t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg)
}
})
t.Run("string to bool", func(t *testing.T) {
var v T
err := yaml.Unmarshal([]byte(`d: str`), &v)
Expand Down Expand Up @@ -2932,4 +3026,4 @@ func TestMapKeyCustomUnmarshaler(t *testing.T) {
if val != "value" {
t.Fatalf("expected to have value \"value\", but got %q", val)
}
}
}
20 changes: 20 additions & 0 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ func TestEncoder(t *testing.T) {
map[string]float32{"v": 0.99},
nil,
},
{
"v: 1e-06\n",
map[string]float32{"v": 1e-06},
nil,
},
{
"v: 1e-06\n",
map[string]float64{"v": 0.000001},
nil,
},
{
"v: 0.123456789\n",
map[string]float64{"v": 0.123456789},
Expand All @@ -100,6 +110,16 @@ func TestEncoder(t *testing.T) {
map[string]float64{"v": 1000000},
nil,
},
{
"v: 1e-06\n",
map[string]float64{"v": 0.000001},
nil,
},
{
"v: 1e-06\n",
map[string]float64{"v": 1e-06},
nil,
},
{
"v: .inf\n",
map[string]interface{}{"v": math.Inf(0)},
Expand Down

0 comments on commit b5f63d5

Please sign in to comment.