diff --git a/schema/diff/diff.go b/schema/diff/diff.go index febb56c3e33f..9626ac1cf200 100644 --- a/schema/diff/diff.go +++ b/schema/diff/diff.go @@ -42,9 +42,8 @@ func CompareModuleSchemas(oldSchema, newSchema schema.ModuleSchema) ModuleSchema diff := ModuleSchemaDiff{} oldSchema.ObjectTypes(func(oldObj schema.ObjectType) bool { - newTyp, found := newSchema.LookupType(oldObj.Name) - newObj, typeMatch := newTyp.(schema.ObjectType) - if !found || !typeMatch { + newObj, found := newSchema.LookupObjectType(oldObj.Name) + if !found { diff.RemovedObjectTypes = append(diff.RemovedObjectTypes, oldObj) return true } @@ -56,18 +55,16 @@ func CompareModuleSchemas(oldSchema, newSchema schema.ModuleSchema) ModuleSchema }) newSchema.ObjectTypes(func(newObj schema.ObjectType) bool { - oldTyp, found := oldSchema.LookupType(newObj.TypeName()) - _, typeMatch := oldTyp.(schema.ObjectType) - if !found || !typeMatch { + _, found := oldSchema.LookupObjectType(newObj.TypeName()) + if !found { diff.AddedObjectTypes = append(diff.AddedObjectTypes, newObj) } return true }) oldSchema.EnumTypes(func(oldEnum schema.EnumType) bool { - newTyp, found := newSchema.LookupType(oldEnum.Name) - newEnum, typeMatch := newTyp.(schema.EnumType) - if !found || !typeMatch { + newEnum, found := newSchema.LookupEnumType(oldEnum.Name) + if !found { diff.RemovedEnumTypes = append(diff.RemovedEnumTypes, oldEnum) return true } @@ -79,9 +76,8 @@ func CompareModuleSchemas(oldSchema, newSchema schema.ModuleSchema) ModuleSchema }) newSchema.EnumTypes(func(newEnum schema.EnumType) bool { - oldTyp, found := oldSchema.LookupType(newEnum.TypeName()) - _, typeMatch := oldTyp.(schema.EnumType) - if !found || !typeMatch { + _, found := oldSchema.LookupEnumType(newEnum.TypeName()) + if !found { diff.AddedEnumTypes = append(diff.AddedEnumTypes, newEnum) } return true diff --git a/schema/field.go b/schema/field.go index af7374e367f7..0d762aa58dc6 100644 --- a/schema/field.go +++ b/schema/field.go @@ -38,14 +38,11 @@ func (c Field) Validate(typeSet TypeSet) error { return fmt.Errorf("enum field %q must have a referenced type", c.Name) } - ty, ok := typeSet.LookupType(c.ReferencedType) + _, ok := typeSet.LookupEnumType(c.ReferencedType) if !ok { - return fmt.Errorf("enum field %q references unknown type %q", c.Name, c.ReferencedType) + return fmt.Errorf("can't find enum type %q referenced by field %q", c.ReferencedType, c.Name) } - if _, ok := ty.(EnumType); !ok { - return fmt.Errorf("enum field %q references non-enum type %q", c.Name, c.ReferencedType) - } default: if c.ReferencedType != "" { return fmt.Errorf("field %q with kind %q cannot have a referenced type", c.Name, c.Kind) @@ -72,14 +69,10 @@ func (c Field) ValidateValue(value interface{}, typeSet TypeSet) error { switch c.Kind { case EnumKind: - ty, ok := typeSet.LookupType(c.ReferencedType) + enumType, ok := typeSet.LookupEnumType(c.ReferencedType) if !ok { return fmt.Errorf("enum field %q references unknown type %q", c.Name, c.ReferencedType) } - enumType, ok := ty.(EnumType) - if !ok { - return fmt.Errorf("enum field %q references non-enum type %q", c.Name, c.ReferencedType) - } err := enumType.ValidateValue(value.(string)) if err != nil { return fmt.Errorf("invalid value for enum field %q: %v", c.Name, err) diff --git a/schema/module_schema.go b/schema/module_schema.go index dbd5365f5563..a7e395e46e8d 100644 --- a/schema/module_schema.go +++ b/schema/module_schema.go @@ -77,7 +77,33 @@ func (s ModuleSchema) LookupType(name string) (Type, bool) { return typ, ok } -// Types calls the provided function for each type in the module schema and stops if the function returns false. +// LookupEnumType is a convenience method that looks up an EnumType by name. +func (s ModuleSchema) LookupEnumType(name string) (t EnumType, found bool) { + typ, found := s.LookupType(name) + if !found { + return EnumType{}, false + } + t, ok := typ.(EnumType) + if !ok { + return EnumType{}, false + } + return t, true +} + +// LookupObjectType is a convenience method that looks up an ObjectType by name. +func (s ModuleSchema) LookupObjectType(name string) (t ObjectType, found bool) { + typ, found := s.LookupType(name) + if !found { + return ObjectType{}, false + } + t, ok := typ.(ObjectType) + if !ok { + return ObjectType{}, false + } + return t, true +} + +// AllTypes calls the provided function for each type in the module schema and stops if the function returns false. // The types are iterated over in sorted order by name. This function is compatible with go 1.23 iterators. func (s ModuleSchema) AllTypes(f func(Type) bool) { keys := make([]string, 0, len(s.types)) diff --git a/schema/module_schema_test.go b/schema/module_schema_test.go index a52085a60f5e..e0d6a3f22ad7 100644 --- a/schema/module_schema_test.go +++ b/schema/module_schema_test.go @@ -187,16 +187,11 @@ func TestModuleSchema_LookupType(t *testing.T) { }, }) - typ, ok := moduleSchema.LookupType("object1") + objectType, ok := moduleSchema.LookupObjectType("object1") if !ok { t.Fatalf("expected to find object type \"object1\"") } - objectType, ok := typ.(ObjectType) - if !ok { - t.Fatalf("expected object type, got %T", typ) - } - if objectType.Name != "object1" { t.Fatalf("expected object type name \"object1\", got %q", objectType.Name) } diff --git a/schema/testing/field.go b/schema/testing/field.go index 87154afec78e..a090c38395f8 100644 --- a/schema/testing/field.go +++ b/schema/testing/field.go @@ -20,10 +20,7 @@ var ( // FieldGen generates random Field's based on the validity criteria of fields. func FieldGen(typeSet schema.TypeSet) *rapid.Generator[schema.Field] { - enumTypes := slices.DeleteFunc(slices.Collect(typeSet.AllTypes), func(t schema.Type) bool { - _, ok := t.(schema.EnumType) - return !ok - }) + enumTypes := slices.Collect(typeSet.EnumTypes) enumTypeSelector := rapid.SampledFrom(enumTypes) return rapid.Custom(func(t *rapid.T) schema.Field { @@ -113,9 +110,8 @@ func baseFieldValue(field schema.Field, typeSet schema.TypeSet) *rapid.Generator case schema.AddressKind: return rapid.SliceOfN(rapid.Byte(), 20, 64).AsAny() case schema.EnumKind: - typ, found := typeSet.LookupType(field.ReferencedType) - enumTyp, ok := typ.(schema.EnumType) - if !found || !ok { + enumTyp, found := typeSet.LookupEnumType(field.ReferencedType) + if !found { panic(fmt.Errorf("enum type %q not found", field.ReferencedType)) } diff --git a/schema/type.go b/schema/type.go index f525e84b817b..9728675cd836 100644 --- a/schema/type.go +++ b/schema/type.go @@ -17,7 +17,13 @@ type Type interface { // Currently, the only implementation is ModuleSchema. type TypeSet interface { // LookupType looks up a type by name. - LookupType(name string) (Type, bool) + LookupType(name string) (t Type, found bool) + + // LookupEnumType is a convenience method that looks up an EnumType by name. + LookupEnumType(name string) (t EnumType, found bool) + + // LookupObjectType is a convenience method that looks up an ObjectType by name. + LookupObjectType(name string) (t ObjectType, found bool) // AllTypes calls the given function for each type in the type set. // This function is compatible with go 1.23 iterators and can be used like this: @@ -26,6 +32,14 @@ type TypeSet interface { // } AllTypes(f func(Type) bool) + // EnumTypes calls the given function for each EnumType in the type set. + // This function is compatible with go 1.23 iterators. + EnumTypes(f func(EnumType) bool) + + // ObjectTypes calls the given function for each ObjectType in the type set. + // This function is compatible with go 1.23 iterators. + ObjectTypes(f func(ObjectType) bool) + // isTypeSet is a private method that ensures that only types in this package can be marked as type sets. isTypeSet() } @@ -40,12 +54,22 @@ var emptyTypeSetInst = emptyTypeSet{} type emptyTypeSet struct{} -// LookupType always returns false because there are no types in an EmptyTypeSet. func (emptyTypeSet) LookupType(string) (Type, bool) { return nil, false } -// Types does nothing because there are no types in an EmptyTypeSet. +func (s emptyTypeSet) LookupEnumType(string) (t EnumType, found bool) { + return EnumType{}, false +} + +func (s emptyTypeSet) LookupObjectType(string) (t ObjectType, found bool) { + return ObjectType{}, false +} + func (emptyTypeSet) AllTypes(func(Type) bool) {} +func (s emptyTypeSet) EnumTypes(func(EnumType) bool) {} + +func (s emptyTypeSet) ObjectTypes(func(ObjectType) bool) {} + func (emptyTypeSet) isTypeSet() {}