Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions jsonpb/jsonpb.go
Original file line number Diff line number Diff line change
Expand Up @@ -1103,13 +1103,26 @@ func (u *Unmarshaler) unmarshalValue(target reflect.Value, inputValue json.RawMe
}
// Check for any oneof fields.
if len(jsonFields) > 0 {
for _, oop := range sprops.OneofTypes {
oneofKeys := make([]string, 0, len(sprops.OneofTypes))
for k := range sprops.OneofTypes {
oneofKeys = append(oneofKeys, k)
}
slices.Sort(oneofKeys)

setOneofFields := make(map[int]bool)
for _, key := range oneofKeys {
oop := sprops.OneofTypes[key]
raw, ok := consumeField(oop.Prop)
if !ok {
continue
}
if setOneofFields[oop.Field] {
return fmt.Errorf("field %q would overwrite already-set oneof %q in %v",
oop.Prop.OrigName, targetType.Field(oop.Field).Name, targetType)
}
nv := reflect.New(oop.Type.Elem())
target.Field(oop.Field).Set(nv)
setOneofFields[oop.Field] = true
if err := u.unmarshalValue(nv.Elem().Field(0), raw, oop.Prop); err != nil {
return err
}
Expand All @@ -1118,7 +1131,14 @@ func (u *Unmarshaler) unmarshalValue(target reflect.Value, inputValue json.RawMe
// Handle proto2 extensions.
if len(jsonFields) > 0 {
if ep, ok := target.Addr().Interface().(proto.Message); ok {
for _, ext := range proto.RegisteredExtensions(ep) {
extensions := proto.RegisteredExtensions(ep)
extIDs := make([]int32, 0, len(extensions))
for id := range extensions {
extIDs = append(extIDs, id)
}
sort.Sort(int32Slice(extIDs))
for _, id := range extIDs {
ext := extensions[id]
name := fmt.Sprintf("[%s]", ext.Name)
raw, ok := jsonFields[name]
if !ok {
Expand Down
25 changes: 25 additions & 0 deletions jsonpb/jsonpb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,8 @@ var unmarshalingShouldError = []struct {
{"StringValue containing invalid character", `{"str": "\U00004E16\U0000754C"}`, &pb.KnownTypes{}},
{"StructValue containing invalid character", `{"str": "\U00004E16\U0000754C"}`, &types.Struct{}},
{"repeated proto3 enum with non array input", `{"rFunny":"PUNS"}`, &proto3pb.Message{RFunny: []proto3pb.Message_Humour{}}},
{"oneof conflict: two fields from the same group", `{"title":"foo","salary":31000}`, new(pb.MsgWithOneof)},
{"oneof conflict: camelCase and orig_name from the same group", `{"Country":"Australia","homeAddress":"Brisbane"}`, new(pb.MsgWithOneof)},
}

func TestUnmarshalingBadInput(t *testing.T) {
Expand All @@ -972,6 +974,29 @@ func TestUnmarshalingBadInput(t *testing.T) {
}
}

// TestUnmarshalOneofConflictDeterminism verifies that unmarshalling a JSON
// object containing multiple keys from the same oneof group always produces a
// consistent result. Before the fix the outcome depended on random map-
// iteration order, so different runs could decode different oneof variants.
func TestUnmarshalOneofConflictDeterminism(t *testing.T) {
const runs = 100
seen := make(map[string]struct{})
for range runs {
var msg pb.MsgWithOneof
err := UnmarshalString(`{"title":"foo","salary":31000}`, &msg)
var key string
if err != nil {
key = "error:" + err.Error()
} else {
key = proto.MarshalTextString(&msg)
}
seen[key] = struct{}{}
}
if len(seen) > 1 {
t.Errorf("non-deterministic unmarshal: got %d different outcomes over %d runs", len(seen), runs)
}
}

type funcResolver func(turl string) (proto.Message, error)

func (fn funcResolver) Resolve(turl string) (proto.Message, error) {
Expand Down