Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deprecated resolver #117

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 38 additions & 21 deletions resolver/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ import (
"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/runtime/protoimpl"
)

const (
newExtensionIndex = "1159"
previousExtensionIndex = "51071"
newExtensionIndex = "1159" // protovalidate versions >= v0.2.0
previousExtensionIndex = "51071" // protovalidate versions < v0.2.0
)

// DefaultResolver resolves protovalidate constraints options from descriptors.
Expand All @@ -34,29 +35,31 @@ type DefaultResolver struct{}
// ResolveMessageConstraints returns the MessageConstraints option set for the
// MessageDescriptor.
func (r DefaultResolver) ResolveMessageConstraints(desc protoreflect.MessageDescriptor) *validate.MessageConstraints {
constraints := resolveExt[protoreflect.MessageDescriptor, *validate.MessageConstraints](desc, validate.E_Message)
if constraints == nil {
constraints = resolveDeprecatedIndex[protoreflect.MessageDescriptor, *validate.MessageConstraints](desc, validate.E_Message)
}
return constraints
return resolveConstraints[validate.MessageConstraints](desc, validate.E_Message)
}

// ResolveOneofConstraints returns the OneofConstraints option set for the
// OneofDescriptor.
func (r DefaultResolver) ResolveOneofConstraints(desc protoreflect.OneofDescriptor) *validate.OneofConstraints {
constraints := resolveExt[protoreflect.OneofDescriptor, *validate.OneofConstraints](desc, validate.E_Oneof)
if constraints == nil {
constraints = resolveDeprecatedIndex[protoreflect.OneofDescriptor, *validate.OneofConstraints](desc, validate.E_Oneof)
}
return constraints
return resolveConstraints[validate.OneofConstraints](desc, validate.E_Oneof)
}

// ResolveFieldConstraints returns the FieldConstraints option set for the
// FieldDescriptor.
func (r DefaultResolver) ResolveFieldConstraints(desc protoreflect.FieldDescriptor) *validate.FieldConstraints {
constraints := resolveExt[protoreflect.FieldDescriptor, *validate.FieldConstraints](desc, validate.E_Field)
return resolveConstraints[validate.FieldConstraints](desc, validate.E_Field)
}

func resolveConstraints[C any, CP interface {
*C
proto.Message
}](
desc protoreflect.Descriptor,
extType *protoimpl.ExtensionInfo,
) (constraints CP) {
constraints = resolveExt[CP](desc.Options(), extType)
if constraints == nil {
constraints = resolveDeprecatedIndex[protoreflect.FieldDescriptor, *validate.FieldConstraints](desc, validate.E_Field)
constraints = resolveDeprecatedIndex[CP](desc.Options(), extType)
}
return constraints
}
Expand All @@ -66,13 +69,14 @@ func (r DefaultResolver) ResolveFieldConstraints(desc protoreflect.FieldDescript
// circumstances, particularly in dynamic or runtime contexts, the underlying
// extension value's type may be a dynamicpb.Message. In this case, we fall back
// through a proto.[Un]Marshal cycle to get it into the concrete type we expect.
func resolveExt[D protoreflect.Descriptor, C proto.Message](
desc D,
func resolveExt[C proto.Message](
options proto.Message,
extType protoreflect.ExtensionType,
) (constraints C) {
num := extType.TypeDescriptor().Number()
var msg proto.Message
proto.RangeExtensions(desc.Options(), func(typ protoreflect.ExtensionType, i interface{}) bool {

proto.RangeExtensions(options, func(typ protoreflect.ExtensionType, i interface{}) bool {
if num != typ.TypeDescriptor().Number() {
return true
}
Expand All @@ -93,16 +97,29 @@ func resolveExt[D protoreflect.Descriptor, C proto.Message](
}

// resolveDeprecatedIndex is a fallback for the deprecated extension index.
func resolveDeprecatedIndex[D protoreflect.Descriptor, C proto.Message](
desc D,
func resolveDeprecatedIndex[C proto.Message](
options proto.Message,
ext *protoimpl.ExtensionInfo,
) C {
return resolveExt[D, C](desc, &protoimpl.ExtensionInfo{
extInfo := &protoimpl.ExtensionInfo{
ExtendedType: ext.ExtendedType,
ExtensionType: ext.ExtensionType,
Field: 51071,
Name: ext.Name,
Tag: strings.Replace(ext.Tag, newExtensionIndex, previousExtensionIndex, 1),
Filename: ext.Filename,
})
}

// detect and handle if there are unknown options
if unknown := options.ProtoReflect().GetUnknown(); len(unknown) > 0 {
opts := options.ProtoReflect().Type().New()
resolver := &protoregistry.Types{}
if err := resolver.RegisterExtension(extInfo); err == nil {
if err = (&proto.UnmarshalOptions{Resolver: resolver}).Unmarshal(unknown, opts.Interface()); err == nil {
options = opts.Interface()
}
}
}

return resolveExt[C](options, extInfo)
}
Loading