Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions internal/impl/codec_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,71 @@ var coderMessageValue = valueCoderFuncs{
merge: mergeMessageValue,
}

func makeProtoStringerFieldCoder(ft reflect.Type) pointerCoderFuncs {
return pointerCoderFuncs{
size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
ps := p.AsValueOf(ft).Elem().Interface().(protoreflect.ProtoStringer)
if reflect.ValueOf(ps).IsNil() {
return 0
}
v, err := ps.ProtoString()
if err != nil {
panic(err)
}
return f.tagsize + protowire.SizeBytes(len(v))
},
marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
ps := p.AsValueOf(ft).Elem().Interface().(protoreflect.ProtoStringer)
if reflect.ValueOf(ps).IsNil() {
return b, nil
}
v, err := ps.ProtoString()
if err != nil {
return nil, err
}
b = protowire.AppendVarint(b, f.wiretag)
b = protowire.AppendString(b, v)
return b, nil
},
unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.BytesType {
return out, errUnknown
}
d, n := protowire.ConsumeString(b)
if n < 0 {
return out, protowire.ParseError(n)
}
out.n = n
ps := p.AsValueOf(ft).Elem()
if ps.IsNil() {
ps.Set(reflect.New(ft.Elem()))
}
if err := ps.Interface().(protoreflect.ProtoStringer).ParseProtoString(d); err != nil {
return out, err
}
return out, nil
},
merge: func(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
srcp := src.AsValueOf(ft).Elem()
dstp := dst.AsValueOf(ft).Elem()
if srcp.IsNil() {
dstp.Set(reflect.Zero(ft))
} else {
if dstp.IsNil() {
dstp.Set(reflect.New(ft.Elem()))
}
v, err := srcp.Interface().(protoreflect.ProtoStringer).ProtoString()
if err != nil {
panic(err)
}
if err := dstp.Interface().(protoreflect.ProtoStringer).ParseProtoString(v); err != nil {
panic(err)
}
}
},
}
}

func sizeGroupValue(v protoreflect.Value, tagsize int, opts marshalOptions) int {
m := v.Message().Interface()
return sizeGroup(m, tagsize, opts)
Expand Down Expand Up @@ -670,6 +735,62 @@ func isInitMessageSliceValue(listv protoreflect.Value) error {
return nil
}

func makeProtoStringerSliceFieldCoder(ft reflect.Type) pointerCoderFuncs {
return pointerCoderFuncs{
size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
n := 0
for _, v := range p.PointerSlice() {
v, err := v.AsValueOf(ft.Elem()).Interface().(protoreflect.ProtoStringer).ProtoString()
if err != nil {
panic(err)
}
n += f.tagsize + protowire.SizeBytes(len(v))
}
return n
},
marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
for _, v := range p.PointerSlice() {
v, err := v.AsValueOf(ft.Elem()).Interface().(protoreflect.ProtoStringer).ProtoString()
if err != nil {
return nil, err
}
b = protowire.AppendVarint(b, f.wiretag)
b = protowire.AppendString(b, v)
}
return b, nil
},
unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != protowire.BytesType {
return out, errUnknown
}
v, n := protowire.ConsumeString(b)
if n < 0 {
return out, protowire.ParseError(n)
}
ps := reflect.New(ft.Elem())
if err := ps.Interface().(protoreflect.ProtoStringer).ParseProtoString(v); err != nil {
return out, err
}
p.AppendPointerSlice(pointerOfValue(ps))
out.n = n
return out, nil
},
merge: func(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
for _, sp := range src.PointerSlice() {
ps := reflect.New(ft.Elem())
v, err := sp.AsValueOf(ft.Elem()).Interface().(protoreflect.ProtoStringer).ProtoString()
if err != nil {
panic(err)
}
if err := ps.Interface().(protoreflect.ProtoStringer).ParseProtoString(v); err != nil {
panic(err)
}
dst.AppendPointerSlice(pointerOfValue(ps))
}
},
}
}

var coderMessageSliceValue = valueCoderFuncs{
size: sizeMessageSliceValue,
marshal: appendMessageSliceValue,
Expand Down
6 changes: 6 additions & 0 deletions internal/impl/codec_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ func fieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo,
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
return nil, coderBytesSliceValidateUTF8
}
if ft.Implements(reflect.TypeOf((*protoreflect.ProtoStringer)(nil)).Elem()) {
return nil, makeProtoStringerSliceFieldCoder(ft)
}
if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
return nil, coderBytesSlice
}
Expand Down Expand Up @@ -197,6 +200,9 @@ func fieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo,
return getMessageInfo(ft), makeMessageFieldCoder(fd, ft)
case fd.Kind() == protoreflect.GroupKind:
return getMessageInfo(ft), makeGroupFieldCoder(fd, ft)
case fd.Kind() == protoreflect.StringKind && ft.Kind() == reflect.Ptr &&
ft.Implements(reflect.TypeOf((*protoreflect.ProtoStringer)(nil)).Elem()):
return nil, makeProtoStringerFieldCoder(ft)
case !fd.HasPresence() && fd.ContainingOneof() == nil:
// Populated oneof fields always encode even if set to the zero value,
// which normally are not encoded in proto3.
Expand Down
38 changes: 38 additions & 0 deletions internal/impl/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ func newSingularConverter(t reflect.Type, fd protoreflect.FieldDescriptor) Conve
if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) {
return &stringConverter{t, defVal(fd, stringZero)}
}
if t.Kind() == reflect.Ptr && t.Implements(reflect.TypeOf((*protoreflect.ProtoStringer)(nil)).Elem()) {
return &protoStringerConverter{t}
}
case protoreflect.BytesKind:
if t.Kind() == reflect.String || (t.Kind() == reflect.Slice && t.Elem() == byteType) {
return &bytesConverter{t, defVal(fd, bytesZero)}
Expand Down Expand Up @@ -340,6 +343,41 @@ func (c *stringConverter) IsValidGo(v reflect.Value) bool {
func (c *stringConverter) New() protoreflect.Value { return c.def }
func (c *stringConverter) Zero() protoreflect.Value { return c.def }

type protoStringerConverter struct {
goType reflect.Type
}

func (c *protoStringerConverter) PBValueOf(v reflect.Value) protoreflect.Value {
if v.Type() != c.goType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), c.goType))
}
s, err := v.Interface().(protoreflect.ProtoStringer).ProtoString()
if err != nil {
panic(err)
}
return protoreflect.ValueOfString(s)
}
func (c *protoStringerConverter) GoValueOf(v protoreflect.Value) reflect.Value {
goVal := reflect.New(c.goType.Elem()).Interface().(protoreflect.ProtoStringer)
if err := goVal.ParseProtoString(v.String()); err != nil {
panic(fmt.Sprintf("could not construct %v from %s", v.String(), c.goType))
}
return reflect.ValueOf(goVal)
}
func (c *protoStringerConverter) IsValidPB(v protoreflect.Value) bool {
_, ok := v.Interface().(string)
return ok
}
func (c *protoStringerConverter) IsValidGo(v reflect.Value) bool {
return v.IsValid() && v.Type() == c.goType
}
func (c *protoStringerConverter) New() protoreflect.Value {
return c.PBValueOf(reflect.New(c.goType.Elem()))
}
func (c *protoStringerConverter) Zero() protoreflect.Value {
return c.PBValueOf(reflect.Zero(c.goType))
}

type bytesConverter struct {
goType reflect.Type
def protoreflect.Value
Expand Down
12 changes: 11 additions & 1 deletion internal/impl/message_reflect_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,18 @@ func fieldInfoForScalar(fd protoreflect.FieldDescriptor, fs reflect.StructField,
return rv.Float() != 0 || math.Signbit(rv.Float())
case reflect.String, reflect.Slice:
return rv.Len() > 0
case reflect.Ptr:
if _, ok := rv.Interface().(protoreflect.ProtoStringer); ok {
return !rv.IsNil()
}
fallthrough
case reflect.Struct:
if _, ok := rv.Interface().(protoreflect.ProtoStringer); ok {
return !rv.IsNil()
}
fallthrough
default:
panic(fmt.Sprintf("field %v has invalid type: %v", fd.FullName(), rv.Type())) // should never happen
panic(fmt.Sprintf("field %v has invalid type: %v %s", fd.FullName(), rv.Type(), rv.Kind())) // should never happen
}
},
clear: func(p pointer) {
Expand Down
143 changes: 140 additions & 3 deletions proto/equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
package proto

import (
"bytes"
"math"
"reflect"

"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/reflect/protoreflect"
)

Expand Down Expand Up @@ -51,7 +54,141 @@ func Equal(x, y Message) bool {
if mx.IsValid() != my.IsValid() {
return false
}
vx := protoreflect.ValueOfMessage(mx)
vy := protoreflect.ValueOfMessage(my)
return vx.Equal(vy)
return equalMessage(mx, my)
}

// equalMessage compares two messages.
func equalMessage(mx, my protoreflect.Message) bool {
if mx.Descriptor() != my.Descriptor() {
return false
}

nx := 0
equal := true
mx.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool {
nx++
vy := my.Get(fd)
equal = my.Has(fd) && equalField(fd, vx, vy)
return equal
})
if !equal {
return false
}
ny := 0
my.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool {
ny++
return true
})
if nx != ny {
return false
}

return equalUnknown(mx.GetUnknown(), my.GetUnknown())
}

// equalField compares two fields.
func equalField(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool {
switch {
case fd.IsList():
return equalList(fd, x.List(), y.List())
case fd.IsMap():
return equalMap(fd, x.Map(), y.Map())
default:
return equalValue(fd, x, y)
}
}

// equalMap compares two maps.
func equalMap(fd protoreflect.FieldDescriptor, x, y protoreflect.Map) bool {
if x.Len() != y.Len() {
return false
}
equal := true
x.Range(func(k protoreflect.MapKey, vx protoreflect.Value) bool {
vy := y.Get(k)
equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
return equal
})
return equal
}

// equalList compares two lists.
func equalList(fd protoreflect.FieldDescriptor, x, y protoreflect.List) bool {
if x.Len() != y.Len() {
return false
}
for i := x.Len() - 1; i >= 0; i-- {
if !equalValue(fd, x.Get(i), y.Get(i)) {
return false
}
}
return true
}

// equalValue compares two singular values.
func equalValue(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool {
switch fd.Kind() {
case protoreflect.BoolKind:
return x.Bool() == y.Bool()
case protoreflect.EnumKind:
return x.Enum() == y.Enum()
case protoreflect.Int32Kind, protoreflect.Sint32Kind,
protoreflect.Int64Kind, protoreflect.Sint64Kind,
protoreflect.Sfixed32Kind, protoreflect.Sfixed64Kind:
return x.Int() == y.Int()
case protoreflect.Uint32Kind, protoreflect.Uint64Kind,
protoreflect.Fixed32Kind, protoreflect.Fixed64Kind:
return x.Uint() == y.Uint()
case protoreflect.FloatKind, protoreflect.DoubleKind:
fx := x.Float()
fy := y.Float()
if math.IsNaN(fx) || math.IsNaN(fy) {
return math.IsNaN(fx) && math.IsNaN(fy)
}
return fx == fy
case protoreflect.StringKind:
if psx, ok := x.Interface().(protoreflect.ProtoStringer); ok {
sx, err := psx.ProtoString()
if err != nil {
panic(err)
}
sy, err := y.Interface().(protoreflect.ProtoStringer).ProtoString()
if err != nil {
panic(err)
}
return sx == sy
}
return x.String() == y.String()
case protoreflect.BytesKind:
return bytes.Equal(x.Bytes(), y.Bytes())
case protoreflect.MessageKind, protoreflect.GroupKind:
return equalMessage(x.Message(), y.Message())
default:
return x.Interface() == y.Interface()
}
}

// equalUnknown compares unknown fields by direct comparison on the raw bytes
// of each individual field number.
func equalUnknown(x, y protoreflect.RawFields) bool {
if len(x) != len(y) {
return false
}
if bytes.Equal([]byte(x), []byte(y)) {
return true
}

mx := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
my := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
for len(x) > 0 {
fnum, _, n := protowire.ConsumeField(x)
mx[fnum] = append(mx[fnum], x[:n]...)
x = x[n:]
}
for len(y) > 0 {
fnum, _, n := protowire.ConsumeField(y)
my[fnum] = append(my[fnum], y[:n]...)
y = y[n:]
}
return reflect.DeepEqual(mx, my)
}
12 changes: 12 additions & 0 deletions reflect/protoreflect/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,15 @@ type Map interface {
// be preserved in marshaling or other operations.
IsValid() bool
}

// ProtoStringer is a Go struct mapped to string
// type in proto. Code gen tool may choose to use
// struct implementing this interface instead of
// string.
type ProtoStringer interface {
// ProtoString converts value to string
ProtoString() (string, error)

// ParseProtoString sets value of struct from string
ParseProtoString(string) error
}
Loading