diff --git a/union.go b/union.go index 031e84f..ac5a18f 100644 --- a/union.go +++ b/union.go @@ -327,7 +327,7 @@ func buildCodecForTypeDescribedBySliceTwoWayJSON(st map[string]*Codec, enclosing } func checkAll(allowedTypes []string, cr *codecInfo, buf []byte) (interface{}, []byte, error) { - for _, name := range cr.allowedTypes { + for _, name := range allowedTypes { if name == "null" { // skip null since we know we already got type float64 continue @@ -344,6 +344,14 @@ func checkAll(allowedTypes []string, cr *codecInfo, buf []byte) (interface{}, [] } return nil, buf, fmt.Errorf("could not decode any json data in input %v", string(buf)) } + +// sortedCopy returns a new slice that is a sorted copy of the provided types. +func sortedCopy(allowedTypes []string) []string { + local := make([]string, len(allowedTypes)) + copy(local, allowedTypes) + sort.Strings(local) + return local +} func nativeAvroFromTextualJSON(cr *codecInfo) func(buf []byte) (interface{}, []byte, error) { return func(buf []byte) (interface{}, []byte, error) { @@ -398,18 +406,14 @@ func nativeAvroFromTextualJSON(cr *codecInfo) func(buf []byte) (interface{}, []b // longNativeFromTextual // int // intNativeFromTextual - - // sorted so it would be - // double, float, int, long - // that makes the priorities right by chance - sort.Strings(cr.allowedTypes) + allowedTypes = sortedCopy(allowedTypes) case map[string]interface{}: // try to decode it as a map // because a map should fail faster than a record // if that fails assume record and return it - sort.Strings(cr.allowedTypes) + allowedTypes = sortedCopy(allowedTypes) } return checkAll(allowedTypes, cr, buf) diff --git a/union_test.go b/union_test.go index b66884f..9f87f97 100644 --- a/union_test.go +++ b/union_test.go @@ -339,3 +339,38 @@ func TestUnionJson(t *testing.T) { testNativeToTextualJSONPass(t, `{"type":"record","name":"kubeEvents","fields":[{"name":"field1","type":"string","default":""},{"name":"field2","type":"string"}]}`, map[string]interface{}{"field1": "", "field2": "deef"}, []byte(`{"field1":"","field2":"deef"}`)) } + +func TestStandardJSONFull_SimpleUnionMutationWouldMislabel(t *testing.T) { + // Minimal repro: union ["null","long"]. If allowedTypes were mutated (sorted) + // after a textual decode, the subsequent binary decode could label index 1 as "null". + codec, err := NewCodecForStandardJSONFull(`["null","long"]`) + if err != nil { + t.Fatal(err) + } + // Trigger textual path that previously sorted cr.allowedTypes. + if _, _, err := codec.NativeFromTextual([]byte("1")); err != nil { + t.Fatalf("textual decode failed: %v", err) + } + // Binary for union index 1 (long) with value 3: 0x02 0x06 + datum, rest, err := codec.NativeFromBinary([]byte{0x02, 0x06}) + if err != nil { + t.Fatalf("binary decode failed: %v", err) + } + if len(rest) != 0 { + t.Fatalf("unexpected trailing bytes: %d", len(rest)) + } + m, ok := datum.(map[string]interface{}) + if !ok { + t.Fatalf("expected union map, got %T", datum) + } + if _, bad := m["null"]; bad { + t.Fatalf("mis-labeled union: got key 'null', want 'long': %v", m) + } + v, ok := m["long"] + if !ok { + t.Fatalf("missing 'long' key: %v", m) + } + if v.(int64) != int64(3) { + t.Fatalf("wrong value: got %v want %v", v, int64(3)) + } +}