diff --git a/schema/schema.go b/schema/schema.go index b26a266..1e26203 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "regexp" "strings" "time" @@ -12,7 +13,8 @@ import ( ) var ( - _timeType = reflect.TypeOf(time.Time{}) + _timeType = reflect.TypeOf(time.Time{}) + _genericTypeRegex = regexp.MustCompile(`(.+)\[(.+)\]`) ) type Schema struct { @@ -161,25 +163,36 @@ func typeName(t reflect.Type) string { for t.Kind() == reflect.Ptr { t = t.Elem() } - return t.String() + + typeName := t.String() + match := _genericTypeRegex.FindStringSubmatch(typeName) + if match == nil { + return typeName + } else { + return fmt.Sprintf("%s..%s", match[1], pkgName(match[2])) + } } func structReference(t reflect.Type) string { return fmt.Sprintf("#/components/schemas/%s", typeName(t)) } -func pkgName(t reflect.Type) string { - parts := strings.Split(t.PkgPath(), "/") +func pkgName(p string) string { + parts := strings.Split(p, "/") return parts[len(parts)-1] } +func typePkgName(t reflect.Type) string { + return pkgName(t.PkgPath()) +} + func (s *Schema) generateStructureSchema(ctx context.Context, doc *openapi3.T, t reflect.Type, inlineLevel int, fieldInfo shared.AttributeInfo, filterObject shared.FilterInterface) (*openapi3.Schema, error) { ret := &openapi3.Schema{ Type: "object", } - pkgName := shared.ToSnakeCase(pkgName(t)) + pkgName := shared.ToSnakeCase(typePkgName(t)) structName := shared.ToSnakeCase(t.Name()) fieldInfo = fieldInfo.AppendPath(structName) @@ -197,6 +210,7 @@ func (s *Schema) generateStructureSchema(ctx context.Context, doc *openapi3.T, t for i := 0; i < t.NumField(); i++ { f := t.Field(i) + fTypeName := typeName(f.Type) tag := ParseJsonTag(f) if (tag.Ignored != nil) && *tag.Ignored { @@ -218,8 +232,8 @@ func (s *Schema) generateStructureSchema(ctx context.Context, doc *openapi3.T, t } //Detect if field is anonymous, look into the schemas and use the same property - if f.Anonymous && fieldSchema.Ref != "" && doc.Components.Schemas[f.Type.String()] != nil && doc.Components.Schemas[f.Type.String()].Value != nil { - for name, property := range doc.Components.Schemas[f.Type.String()].Value.Properties { + if f.Anonymous && fieldSchema.Ref != "" && doc.Components.Schemas[fTypeName] != nil && doc.Components.Schemas[fTypeName].Value != nil { + for name, property := range doc.Components.Schemas[fTypeName].Value.Properties { ret.WithPropertyRef(name, property) } diff --git a/schema/schema_test.go b/schema/schema_test.go index 2afbf65..0e21c1c 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -28,6 +28,14 @@ type RecursiveUser struct { Group *RecursiveGroup } +type Generic[T any] struct { + Name string + Value T +} + +type Embedded struct { +} + func checkGeneratedType(g *goblin.G, ctx context.Context, schemaPtr **Schema, docPtr **openapi3.T, value interface{}, expected string) { g.It(fmt.Sprintf("should generate inline type for %T", value), func() { s := *schemaPtr @@ -311,6 +319,41 @@ func TestSchema(t *testing.T) { "type": "string", "format": "date-time" }`) + + g.It("should generate valid name for generic structures", func() { + st := Generic[Embedded]{ + Name: "john", + } + typ := reflect.TypeOf(&st) + schema, err := s.GenerateSchemaFor(ctx, doc, typ) + require.NoError(g, err) + + data, err := json.Marshal(schema) + require.NoError(g, err) + // check returned schema + assert.JSONEq(g, `{ + "$ref": "#/components/schemas/schema.Generic..schema.Embedded" + }`, string(data)) + + userSchema, found := doc.Components.Schemas[typeName(typ.Elem())] + require.True(g, found) + + data, err = json.Marshal(userSchema) + require.NoError(g, err) + + // check returned schema + assert.JSONEq(g, `{ + "type": "object", + "properties": { + "Name": { + "type": "string" + }, + "Value": { + "$ref": "#/components/schemas/schema.Embedded" + } + } + }`, string(data)) + }) }) })