Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions union.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ func buildCodecForTypeDescribedBySliceTwoWayJSON(st map[string]*Codec, enclosing
}

func checkAll(allowedTypes []string, cr *codecInfo, buf []byte) (interface{}, []byte, error) {
for _, name := range cr.allowedTypes {
for _, name := range allowedTypes {
if name == "null" {
// skip null since we know we already got type float64
continue
Expand All @@ -344,6 +344,14 @@ func checkAll(allowedTypes []string, cr *codecInfo, buf []byte) (interface{}, []
}
return nil, buf, fmt.Errorf("could not decode any json data in input %v", string(buf))
}

// sortedCopy returns a new slice that is a sorted copy of the provided types.
func sortedCopy(allowedTypes []string) []string {
local := make([]string, len(allowedTypes))
copy(local, allowedTypes)
sort.Strings(local)
return local
}
func nativeAvroFromTextualJSON(cr *codecInfo) func(buf []byte) (interface{}, []byte, error) {
return func(buf []byte) (interface{}, []byte, error) {

Expand Down Expand Up @@ -398,18 +406,14 @@ func nativeAvroFromTextualJSON(cr *codecInfo) func(buf []byte) (interface{}, []b
// longNativeFromTextual
// int
// intNativeFromTextual

// sorted so it would be
// double, float, int, long
// that makes the priorities right by chance
sort.Strings(cr.allowedTypes)
allowedTypes = sortedCopy(allowedTypes)

case map[string]interface{}:

// try to decode it as a map
// because a map should fail faster than a record
// if that fails assume record and return it
sort.Strings(cr.allowedTypes)
allowedTypes = sortedCopy(allowedTypes)
}

return checkAll(allowedTypes, cr, buf)
Expand Down
35 changes: 35 additions & 0 deletions union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,38 @@ func TestUnionJson(t *testing.T) {
testNativeToTextualJSONPass(t, `{"type":"record","name":"kubeEvents","fields":[{"name":"field1","type":"string","default":""},{"name":"field2","type":"string"}]}`, map[string]interface{}{"field1": "", "field2": "deef"}, []byte(`{"field1":"","field2":"deef"}`))

}

func TestStandardJSONFull_SimpleUnionMutationWouldMislabel(t *testing.T) {
// Minimal repro: union ["null","long"]. If allowedTypes were mutated (sorted)
// after a textual decode, the subsequent binary decode could label index 1 as "null".
codec, err := NewCodecForStandardJSONFull(`["null","long"]`)
if err != nil {
t.Fatal(err)
}
// Trigger textual path that previously sorted cr.allowedTypes.
if _, _, err := codec.NativeFromTextual([]byte("1")); err != nil {
t.Fatalf("textual decode failed: %v", err)
}
// Binary for union index 1 (long) with value 3: 0x02 0x06
datum, rest, err := codec.NativeFromBinary([]byte{0x02, 0x06})
if err != nil {
t.Fatalf("binary decode failed: %v", err)
}
if len(rest) != 0 {
t.Fatalf("unexpected trailing bytes: %d", len(rest))
}
m, ok := datum.(map[string]interface{})
if !ok {
t.Fatalf("expected union map, got %T", datum)
}
if _, bad := m["null"]; bad {
t.Fatalf("mis-labeled union: got key 'null', want 'long': %v", m)
}
v, ok := m["long"]
if !ok {
t.Fatalf("missing 'long' key: %v", m)
}
if v.(int64) != int64(3) {
t.Fatalf("wrong value: got %v want %v", v, int64(3))
}
}
Loading