Skip to content

Commit

Permalink
Fix deprecated resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
rodaine committed Apr 12, 2024
1 parent 8f25afb commit cbaad96
Showing 1 changed file with 38 additions and 21 deletions.
59 changes: 38 additions & 21 deletions resolver/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ import (
"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/runtime/protoimpl"
)

const (
newExtensionIndex = "1159"
previousExtensionIndex = "51071"
newExtensionIndex = "1159" // protovalidate versions >= v0.2.0
previousExtensionIndex = "51071" // protovalidate versions < v0.2.0
)

// DefaultResolver resolves protovalidate constraints options from descriptors.
Expand All @@ -34,29 +35,31 @@ type DefaultResolver struct{}
// ResolveMessageConstraints returns the MessageConstraints option set for the
// MessageDescriptor.
func (r DefaultResolver) ResolveMessageConstraints(desc protoreflect.MessageDescriptor) *validate.MessageConstraints {
constraints := resolveExt[protoreflect.MessageDescriptor, *validate.MessageConstraints](desc, validate.E_Message)
if constraints == nil {
constraints = resolveDeprecatedIndex[protoreflect.MessageDescriptor, *validate.MessageConstraints](desc, validate.E_Message)
}
return constraints
return resolveConstraints[validate.MessageConstraints](desc, validate.E_Message)
}

// ResolveOneofConstraints returns the OneofConstraints option set for the
// OneofDescriptor.
func (r DefaultResolver) ResolveOneofConstraints(desc protoreflect.OneofDescriptor) *validate.OneofConstraints {
constraints := resolveExt[protoreflect.OneofDescriptor, *validate.OneofConstraints](desc, validate.E_Oneof)
if constraints == nil {
constraints = resolveDeprecatedIndex[protoreflect.OneofDescriptor, *validate.OneofConstraints](desc, validate.E_Oneof)
}
return constraints
return resolveConstraints[validate.OneofConstraints](desc, validate.E_Oneof)
}

// ResolveFieldConstraints returns the FieldConstraints option set for the
// FieldDescriptor.
func (r DefaultResolver) ResolveFieldConstraints(desc protoreflect.FieldDescriptor) *validate.FieldConstraints {
constraints := resolveExt[protoreflect.FieldDescriptor, *validate.FieldConstraints](desc, validate.E_Field)
return resolveConstraints[validate.FieldConstraints](desc, validate.E_Field)
}

func resolveConstraints[C any, CP interface {
*C
proto.Message
}](
desc protoreflect.Descriptor,
extType *protoimpl.ExtensionInfo,
) (constraints CP) {
constraints = resolveExt[CP](desc.Options(), extType)
if constraints == nil {
constraints = resolveDeprecatedIndex[protoreflect.FieldDescriptor, *validate.FieldConstraints](desc, validate.E_Field)
constraints = resolveDeprecatedIndex[CP](desc.Options(), extType)
}
return constraints
}
Expand All @@ -66,13 +69,14 @@ func (r DefaultResolver) ResolveFieldConstraints(desc protoreflect.FieldDescript
// circumstances, particularly in dynamic or runtime contexts, the underlying
// extension value's type may be a dynamicpb.Message. In this case, we fall back
// through a proto.[Un]Marshal cycle to get it into the concrete type we expect.
func resolveExt[D protoreflect.Descriptor, C proto.Message](
desc D,
func resolveExt[C proto.Message](
options proto.Message,
extType protoreflect.ExtensionType,
) (constraints C) {
num := extType.TypeDescriptor().Number()
var msg proto.Message
proto.RangeExtensions(desc.Options(), func(typ protoreflect.ExtensionType, i interface{}) bool {

proto.RangeExtensions(options, func(typ protoreflect.ExtensionType, i interface{}) bool {
if num != typ.TypeDescriptor().Number() {
return true
}
Expand All @@ -93,16 +97,29 @@ func resolveExt[D protoreflect.Descriptor, C proto.Message](
}

// resolveDeprecatedIndex is a fallback for the deprecated extension index.
func resolveDeprecatedIndex[D protoreflect.Descriptor, C proto.Message](
desc D,
func resolveDeprecatedIndex[C proto.Message](
options proto.Message,
ext *protoimpl.ExtensionInfo,
) C {
return resolveExt[D, C](desc, &protoimpl.ExtensionInfo{
extInfo := &protoimpl.ExtensionInfo{
ExtendedType: ext.ExtendedType,
ExtensionType: ext.ExtensionType,
Field: 51071,
Name: ext.Name,
Tag: strings.Replace(ext.Tag, newExtensionIndex, previousExtensionIndex, 1),
Filename: ext.Filename,
})
}

// detect and handle if there are unknown options
if unknown := options.ProtoReflect().GetUnknown(); len(unknown) > 0 {
opts := options.ProtoReflect().Type().New()
resolver := &protoregistry.Types{}
if err := resolver.RegisterExtension(extInfo); err == nil {
if err = (&proto.UnmarshalOptions{Resolver: resolver}).Unmarshal(unknown, opts.Interface()); err == nil {
options = opts.Interface()
}
}
}

return resolveExt[C](options, extInfo)
}

0 comments on commit cbaad96

Please sign in to comment.