From 2c3795b761ae5545d5cd5b993f1a629f7a89abc4 Mon Sep 17 00:00:00 2001 From: Keith Noguchi Date: Wed, 5 Jun 2024 23:47:45 -0700 Subject: [PATCH 1/6] ProtoStringer interface type --- reflect/protoreflect/value.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/reflect/protoreflect/value.go b/reflect/protoreflect/value.go index a7b0d06ff..6d586446c 100644 --- a/reflect/protoreflect/value.go +++ b/reflect/protoreflect/value.go @@ -283,3 +283,15 @@ type Map interface { // be preserved in marshaling or other operations. IsValid() bool } + +// ProtoStringer is a Go struct mapped to string +// type in proto. Code gen tool may choose to use +// struct implementing this interface instead of +// string. +type ProtoStringer interface { + // ProtoString converts value to string + ProtoString() (string, error) + + // ParseProtoString sets value of struct from string + ParseProtoString(string) error +} From 0b1972b97cd9d1f27a057a27bab0138b53157a15 Mon Sep 17 00:00:00 2001 From: Keith Noguchi Date: Thu, 6 Jun 2024 00:06:40 -0700 Subject: [PATCH 2/6] ValueOfProtoString function --- reflect/protoreflect/value_union.go | 50 ++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/reflect/protoreflect/value_union.go b/reflect/protoreflect/value_union.go index 160309731..d37e3bb96 100644 --- a/reflect/protoreflect/value_union.go +++ b/reflect/protoreflect/value_union.go @@ -13,21 +13,21 @@ import ( // The Value is used to represent all possible values a field may take. // The following shows which Go type is used to represent each proto [Kind]: // -// ╔════════════╤═════════════════════════════════════╗ -// ║ Go type │ Protobuf kind ║ -// ╠════════════╪═════════════════════════════════════╣ -// ║ bool │ BoolKind ║ -// ║ int32 │ Int32Kind, Sint32Kind, Sfixed32Kind ║ -// ║ int64 │ Int64Kind, Sint64Kind, Sfixed64Kind ║ -// ║ uint32 │ Uint32Kind, Fixed32Kind ║ -// ║ uint64 │ Uint64Kind, Fixed64Kind ║ -// ║ float32 │ FloatKind ║ -// ║ float64 │ DoubleKind ║ -// ║ string │ StringKind ║ -// ║ []byte │ BytesKind ║ -// ║ EnumNumber │ EnumKind ║ -// ║ Message │ MessageKind, GroupKind ║ -// ╚════════════╧═════════════════════════════════════╝ +// ╔═══════════════════════╤═════════════════════════════════════╗ +// ║ Go type │ Protobuf kind ║ +// ╠═══════════════════════╪═════════════════════════════════════╣ +// ║ bool │ BoolKind ║ +// ║ int32 │ Int32Kind, Sint32Kind, Sfixed32Kind ║ +// ║ int64 │ Int64Kind, Sint64Kind, Sfixed64Kind ║ +// ║ uint32 │ Uint32Kind, Fixed32Kind ║ +// ║ uint64 │ Uint64Kind, Fixed64Kind ║ +// ║ float32 │ FloatKind ║ +// ║ float64 │ DoubleKind ║ +// ║ string, ProtoStringer │ StringKind ║ +// ║ []byte │ BytesKind ║ +// ║ EnumNumber │ EnumKind ║ +// ║ Message │ MessageKind, GroupKind ║ +// ╚═══════════════════════╧═════════════════════════════════════╝ // // Multiple protobuf Kinds may be represented by a single Go type if the type // can losslessly represent the information for the proto kind. For example, @@ -111,6 +111,8 @@ func ValueOf(v interface{}) Value { return ValueOfEnum(v) case Message, List, Map: return valueOfIface(v) + case ProtoStringer: + return valueOfIface(v) case ProtoMessage: panic(fmt.Sprintf("invalid proto.Message(%T) type, expected a protoreflect.Message type", v)) default: @@ -187,6 +189,11 @@ func ValueOfMap(v Map) Value { return valueOfIface(v) } +// ValueOfProtoString returns a new ProtoStringer value. +func ValueOfProtoString(v ProtoStringer) Value { + return valueOfIface(v) +} + // IsValid reports whether v is populated with a value. func (v Value) IsValid() bool { return v.typ != nilType @@ -256,6 +263,8 @@ func (v Value) typeName() string { return "list" case Map: return "map" + case ProtoStringer: + return "protostringer" default: return fmt.Sprintf("", v) } @@ -347,6 +356,17 @@ func (v Value) Message() Message { } } +// [ProtoStringer] returns v as a [ProtoStringer] and panics if the type is not +// a [ProtoStringer]. +func (v Value) ProtoStringer() ProtoStringer { + switch vi := v.getIface().(type) { + case ProtoStringer: + return vi + default: + panic(v.panicMessage("protostringer")) + } +} + // List returns v as a [List] and panics if the type is not a [List]. func (v Value) List() List { switch vi := v.getIface().(type) { From 5fe152fdf37c68284c577c0e008e1d453d80a112 Mon Sep 17 00:00:00 2001 From: Keith Noguchi Date: Thu, 6 Jun 2024 00:44:06 -0700 Subject: [PATCH 3/6] internal/impl: Add protoStringerConverter in newSingularConverter --- internal/impl/convert.go | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/internal/impl/convert.go b/internal/impl/convert.go index 185ef2efa..6e9bcdc74 100644 --- a/internal/impl/convert.go +++ b/internal/impl/convert.go @@ -127,6 +127,9 @@ func newSingularConverter(t reflect.Type, fd protoreflect.FieldDescriptor) Conve if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) { return &stringConverter{t, defVal(fd, stringZero)} } + if t.Kind() == reflect.Ptr && t.Implements(reflect.TypeOf((*protoreflect.ProtoStringer)(nil)).Elem()) { + return &protoStringerConverter{t} + } case protoreflect.BytesKind: if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) { return &bytesConverter{t, defVal(fd, bytesZero)} @@ -340,6 +343,41 @@ func (c *stringConverter) IsValidGo(v reflect.Value) bool { func (c *stringConverter) New() protoreflect.Value { return c.def } func (c *stringConverter) Zero() protoreflect.Value { return c.def } +type protoStringerConverter struct { + goType reflect.Type +} + +func (c *protoStringerConverter) PBValueOf(v reflect.Value) protoreflect.Value { + if v.Type() != c.goType { + panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType)) + } + s, err := v.Interface().(protoreflect.ProtoStringer).ProtoString() + if err != nil { + panic(err) + } + return protoreflect.ValueOfString(s) +} +func (c *protoStringerConverter) GoValueOf(v protoreflect.Value) reflect.Value { + goVal := reflect.New(c.goType.Elem()).Interface().(protoreflect.ProtoStringer) + if err := goVal.ParseProtoString(v.String()); err != nil { + panic(fmt.Sprintf("could not construct %v from %s", v.String(), c.goType)) + } + return reflect.ValueOf(goVal) +} +func (c *protoStringerConverter) IsValidPB(v protoreflect.Value) bool { + _, ok := v.Interface().(string) + return ok +} +func (c *protoStringerConverter) IsValidGo(v reflect.Value) bool { + return v.IsValid() && v.Type() == c.goType +} +func (c *protoStringerConverter) New() protoreflect.Value { + return c.PBValueOf(reflect.New(c.goType.Elem())) +} +func (c *protoStringerConverter) Zero() protoreflect.Value { + return c.PBValueOf(reflect.Zero(c.goType)) +} + type bytesConverter struct { goType reflect.Type def protoreflect.Value From dc9ab30349a7e053a4c254f3d13272aa8352f340 Mon Sep 17 00:00:00 2001 From: Keith Noguchi Date: Thu, 6 Jun 2024 01:15:08 -0700 Subject: [PATCH 4/6] internal/impl: Add ProtoStringer in fieldCoder --- internal/impl/codec_field.go | 121 ++++++++++++++++++++++++++++++++++ internal/impl/codec_tables.go | 6 ++ 2 files changed, 127 insertions(+) diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go index 78ee47e44..2d2f22c0b 100644 --- a/internal/impl/codec_field.go +++ b/internal/impl/codec_field.go @@ -334,6 +334,71 @@ var coderMessageValue = valueCoderFuncs{ merge: mergeMessageValue, } +func makeProtoStringerFieldCoder(ft reflect.Type) pointerCoderFuncs { + return pointerCoderFuncs{ + size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int { + ps := p.AsValueOf(ft).Elem().Interface().(protoreflect.ProtoStringer) + if reflect.ValueOf(ps).IsNil() { + return 0 + } + v, err := ps.ProtoString() + if err != nil { + panic(err) + } + return f.tagsize + protowire.SizeBytes(len(v)) + }, + marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { + ps := p.AsValueOf(ft).Elem().Interface().(protoreflect.ProtoStringer) + if reflect.ValueOf(ps).IsNil() { + return b, nil + } + v, err := ps.ProtoString() + if err != nil { + return nil, err + } + b = protowire.AppendVarint(b, f.wiretag) + b = protowire.AppendString(b, v) + return b, nil + }, + unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { + if wtyp != protowire.BytesType { + return out, errUnknown + } + d, n := protowire.ConsumeString(b) + if n < 0 { + return out, protowire.ParseError(n) + } + out.n = n + ps := p.AsValueOf(ft).Elem() + if ps.IsNil() { + ps.Set(reflect.New(ft.Elem())) + } + if err := ps.Interface().(protoreflect.ProtoStringer).ParseProtoString(d); err != nil { + return out, err + } + return out, nil + }, + merge: func(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) { + srcp := src.AsValueOf(ft).Elem() + dstp := dst.AsValueOf(ft).Elem() + if srcp.IsNil() { + dstp.Set(reflect.Zero(ft)) + } else { + if dstp.IsNil() { + dstp.Set(reflect.New(ft.Elem())) + } + v, err := srcp.Interface().(protoreflect.ProtoStringer).ProtoString() + if err != nil { + panic(err) + } + if err := dstp.Interface().(protoreflect.ProtoStringer).ParseProtoString(v); err != nil { + panic(err) + } + } + }, + } +} + func sizeGroupValue(v protoreflect.Value, tagsize int, opts marshalOptions) int { m := v.Message().Interface() return sizeGroup(m, tagsize, opts) @@ -670,6 +735,62 @@ func isInitMessageSliceValue(listv protoreflect.Value) error { return nil } +func makeProtoStringerSliceFieldCoder(ft reflect.Type) pointerCoderFuncs { + return pointerCoderFuncs{ + size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int { + n := 0 + for _, v := range p.PointerSlice() { + v, err := v.AsValueOf(ft.Elem()).Interface().(protoreflect.ProtoStringer).ProtoString() + if err != nil { + panic(err) + } + n += f.tagsize + protowire.SizeBytes(len(v)) + } + return n + }, + marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { + for _, v := range p.PointerSlice() { + v, err := v.AsValueOf(ft.Elem()).Interface().(protoreflect.ProtoStringer).ProtoString() + if err != nil { + return nil, err + } + b = protowire.AppendVarint(b, f.wiretag) + b = protowire.AppendString(b, v) + } + return b, nil + }, + unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { + if wtyp != protowire.BytesType { + return out, errUnknown + } + v, n := protowire.ConsumeString(b) + if n < 0 { + return out, protowire.ParseError(n) + } + ps := reflect.New(ft.Elem()) + if err := ps.Interface().(protoreflect.ProtoStringer).ParseProtoString(v); err != nil { + return out, err + } + p.AppendPointerSlice(pointerOfValue(ps)) + out.n = n + return out, nil + }, + merge: func(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { + for _, sp := range src.PointerSlice() { + ps := reflect.New(ft.Elem()) + v, err := sp.AsValueOf(ft.Elem()).Interface().(protoreflect.ProtoStringer).ProtoString() + if err != nil { + panic(err) + } + if err := ps.Interface().(protoreflect.ProtoStringer).ParseProtoString(v); err != nil { + panic(err) + } + dst.AppendPointerSlice(pointerOfValue(ps)) + } + }, + } +} + var coderMessageSliceValue = valueCoderFuncs{ size: sizeMessageSliceValue, marshal: appendMessageSliceValue, diff --git a/internal/impl/codec_tables.go b/internal/impl/codec_tables.go index 13077751e..bec358f89 100644 --- a/internal/impl/codec_tables.go +++ b/internal/impl/codec_tables.go @@ -111,6 +111,9 @@ func fieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) { return nil, coderBytesSliceValidateUTF8 } + if ft.Implements(reflect.TypeOf((*protoreflect.ProtoStringer)(nil)).Elem()) { + return nil, makeProtoStringerSliceFieldCoder(ft) + } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 { return nil, coderBytesSlice } @@ -197,6 +200,9 @@ func fieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, return getMessageInfo(ft), makeMessageFieldCoder(fd, ft) case fd.Kind() == protoreflect.GroupKind: return getMessageInfo(ft), makeGroupFieldCoder(fd, ft) + case fd.Kind() == protoreflect.StringKind && ft.Kind() == reflect.Ptr && + ft.Implements(reflect.TypeOf((*protoreflect.ProtoStringer)(nil)).Elem()): + return nil, makeProtoStringerFieldCoder(ft) case !fd.HasPresence() && fd.ContainingOneof() == nil: // Populated oneof fields always encode even if set to the zero value, // which normally are not encoded in proto3. From a5c2af898127ebac7f9c18fe85cf18c777771828 Mon Sep 17 00:00:00 2001 From: Keith Noguchi Date: Thu, 6 Jun 2024 01:40:11 -0700 Subject: [PATCH 5/6] proto: Extend proto.Equal() --- proto/equal.go | 143 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 140 insertions(+), 3 deletions(-) diff --git a/proto/equal.go b/proto/equal.go index 1a0be1b03..2cebc28fa 100644 --- a/proto/equal.go +++ b/proto/equal.go @@ -5,8 +5,11 @@ package proto import ( + "bytes" + "math" "reflect" + "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -51,7 +54,141 @@ func Equal(x, y Message) bool { if mx.IsValid() != my.IsValid() { return false } - vx := protoreflect.ValueOfMessage(mx) - vy := protoreflect.ValueOfMessage(my) - return vx.Equal(vy) + return equalMessage(mx, my) +} + +// equalMessage compares two messages. +func equalMessage(mx, my protoreflect.Message) bool { + if mx.Descriptor() != my.Descriptor() { + return false + } + + nx := 0 + equal := true + mx.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool { + nx++ + vy := my.Get(fd) + equal = my.Has(fd) && equalField(fd, vx, vy) + return equal + }) + if !equal { + return false + } + ny := 0 + my.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool { + ny++ + return true + }) + if nx != ny { + return false + } + + return equalUnknown(mx.GetUnknown(), my.GetUnknown()) +} + +// equalField compares two fields. +func equalField(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool { + switch { + case fd.IsList(): + return equalList(fd, x.List(), y.List()) + case fd.IsMap(): + return equalMap(fd, x.Map(), y.Map()) + default: + return equalValue(fd, x, y) + } +} + +// equalMap compares two maps. +func equalMap(fd protoreflect.FieldDescriptor, x, y protoreflect.Map) bool { + if x.Len() != y.Len() { + return false + } + equal := true + x.Range(func(k protoreflect.MapKey, vx protoreflect.Value) bool { + vy := y.Get(k) + equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy) + return equal + }) + return equal +} + +// equalList compares two lists. +func equalList(fd protoreflect.FieldDescriptor, x, y protoreflect.List) bool { + if x.Len() != y.Len() { + return false + } + for i := x.Len() - 1; i >= 0; i-- { + if !equalValue(fd, x.Get(i), y.Get(i)) { + return false + } + } + return true +} + +// equalValue compares two singular values. +func equalValue(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool { + switch fd.Kind() { + case protoreflect.BoolKind: + return x.Bool() == y.Bool() + case protoreflect.EnumKind: + return x.Enum() == y.Enum() + case protoreflect.Int32Kind, protoreflect.Sint32Kind, + protoreflect.Int64Kind, protoreflect.Sint64Kind, + protoreflect.Sfixed32Kind, protoreflect.Sfixed64Kind: + return x.Int() == y.Int() + case protoreflect.Uint32Kind, protoreflect.Uint64Kind, + protoreflect.Fixed32Kind, protoreflect.Fixed64Kind: + return x.Uint() == y.Uint() + case protoreflect.FloatKind, protoreflect.DoubleKind: + fx := x.Float() + fy := y.Float() + if math.IsNaN(fx) || math.IsNaN(fy) { + return math.IsNaN(fx) && math.IsNaN(fy) + } + return fx == fy + case protoreflect.StringKind: + if psx, ok := x.Interface().(protoreflect.ProtoStringer); ok { + sx, err := psx.ProtoString() + if err != nil { + panic(err) + } + sy, err := y.Interface().(protoreflect.ProtoStringer).ProtoString() + if err != nil { + panic(err) + } + return sx == sy + } + return x.String() == y.String() + case protoreflect.BytesKind: + return bytes.Equal(x.Bytes(), y.Bytes()) + case protoreflect.MessageKind, protoreflect.GroupKind: + return equalMessage(x.Message(), y.Message()) + default: + return x.Interface() == y.Interface() + } +} + +// equalUnknown compares unknown fields by direct comparison on the raw bytes +// of each individual field number. +func equalUnknown(x, y protoreflect.RawFields) bool { + if len(x) != len(y) { + return false + } + if bytes.Equal([]byte(x), []byte(y)) { + return true + } + + mx := make(map[protoreflect.FieldNumber]protoreflect.RawFields) + my := make(map[protoreflect.FieldNumber]protoreflect.RawFields) + for len(x) > 0 { + fnum, _, n := protowire.ConsumeField(x) + mx[fnum] = append(mx[fnum], x[:n]...) + x = x[n:] + } + for len(y) > 0 { + fnum, _, n := protowire.ConsumeField(y) + my[fnum] = append(my[fnum], y[:n]...) + y = y[n:] + } + return reflect.DeepEqual(mx, my) } From d3d592918f66982fdcb9fe4cb4135c2542cb2376 Mon Sep 17 00:00:00 2001 From: Keith Noguchi Date: Thu, 6 Jun 2024 01:49:35 -0700 Subject: [PATCH 6/6] internal/impl: Extend fieldInfoForScalar() with ProtoStringer --- internal/impl/message_reflect_field.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/internal/impl/message_reflect_field.go b/internal/impl/message_reflect_field.go index 986322b19..512364d81 100644 --- a/internal/impl/message_reflect_field.go +++ b/internal/impl/message_reflect_field.go @@ -292,8 +292,18 @@ func fieldInfoForScalar(fd protoreflect.FieldDescriptor, fs reflect.StructField, return rv.Float() != 0 || math.Signbit(rv.Float()) case reflect.String, reflect.Slice: return rv.Len() > 0 + case reflect.Ptr: + if _, ok := rv.Interface().(protoreflect.ProtoStringer); ok { + return !rv.IsNil() + } + fallthrough + case reflect.Struct: + if _, ok := rv.Interface().(protoreflect.ProtoStringer); ok { + return !rv.IsNil() + } + fallthrough default: - panic(fmt.Sprintf("field %v has invalid type: %v", fd.FullName(), rv.Type())) // should never happen + panic(fmt.Sprintf("field %v has invalid type: %v %s", fd.FullName(), rv.Type(), rv.Kind())) // should never happen } }, clear: func(p pointer) {