Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"

"github.com/danielgtaylor/huma/v2/validation"
Expand Down Expand Up @@ -676,6 +677,29 @@ type SchemaTransformer interface {
TransformSchema(r Registry, s *Schema) *Schema
}

type SchemaFactory func(Registry) *Schema

var (
schemaFactoryMu sync.RWMutex
schemaFactories = map[reflect.Type]SchemaFactory{}
)

// RegisterTypeSchema associates a schema factory with the given type.
// The provided factory runs whenever SchemaFromType handles that type.
// Later calls replace any previously registered factory.
func RegisterTypeSchema(t reflect.Type, factory SchemaFactory) {
if t == nil {
panic("huma: RegisterTypeSchema called with nil type")
}
if factory == nil {
panic("huma: RegisterTypeSchema called with nil factory")
}

schemaFactoryMu.Lock()
schemaFactories[t] = factory
schemaFactoryMu.Unlock()
}

// SchemaFromType returns a schema for a given type, using the registry to
// possibly create references for nested structs. The schema that is returned
// can then be passed to `huma.Validate` to efficiently validate incoming
Expand All @@ -699,12 +723,46 @@ func SchemaFromType(r Registry, t reflect.Type) *Schema {
return s
}

func lookupTypeSchemaFactory(t reflect.Type) SchemaFactory {
if t == nil {
return nil
}

schemaFactoryMu.RLock()
factory := schemaFactories[t]
schemaFactoryMu.RUnlock()
return factory
}

func schemaFromRegisteredFactory(r Registry, t reflect.Type) *Schema {
factory := lookupTypeSchemaFactory(t)
if factory == nil {
return nil
}

s := factory(r)
if s == nil {
return nil
}

s.PrecomputeMessages()
return s
}

func schemaFromType(r Registry, t reflect.Type) *Schema {
if custom := schemaFromRegisteredFactory(r, t); custom != nil {
return custom
}

isPointer := t.Kind() == reflect.Pointer

s := Schema{}
t = deref(t)

if custom := schemaFromRegisteredFactory(r, t); custom != nil {
return custom
}

v := reflect.New(t).Interface()
if sp, ok := v.(SchemaProvider); ok {
// Special case: type provides its own schema. Do not try to generate.
Expand Down
41 changes: 41 additions & 0 deletions schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1544,3 +1544,44 @@ func TestSchemaTransformer(t *testing.T) {
updateSchema2 := huma.SchemaFromType(r, reflect.TypeOf(ExampleUpdateStruct{}))
validateSchema(updateSchema2)
}

type customSchemaInt int64

func TestRegisterTypeSchema(t *testing.T) {
huma.RegisterTypeSchema(reflect.TypeOf(customSchemaInt(0)), func(r huma.Registry) *huma.Schema {
return &huma.Schema{Type: huma.TypeString}
})

registry := huma.NewMapRegistry("#/components/schemas/", huma.DefaultSchemaNamer)

type input struct {
ID customSchemaInt `json:"id"`
Optional *customSchemaInt `json:"optional,omitempty"`
Nested struct {
Value customSchemaInt `json:"value"`
} `json:"nested"`
}

schema := huma.SchemaFromType(registry, reflect.TypeOf(input{}))
require.NotNil(t, schema)
require.Contains(t, schema.Properties, "id")
assert.Equal(t, huma.TypeString, schema.Properties["id"].Type)

require.Contains(t, schema.Properties, "optional")
assert.Equal(t, huma.TypeString, schema.Properties["optional"].Type)

nested := schema.Properties["nested"]
require.NotNil(t, nested)
nestedSchema := nested
if nested.Ref != "" {
nestedSchema = registry.SchemaFromRef(nested.Ref)
require.NotNil(t, nestedSchema)
}
assert.Equal(t, huma.TypeObject, nestedSchema.Type)
require.Contains(t, nestedSchema.Properties, "value")
assert.Equal(t, huma.TypeString, nestedSchema.Properties["value"].Type)

custom := registry.Schema(reflect.TypeOf(customSchemaInt(0)), true, "Standalone")
require.NotNil(t, custom)
assert.Equal(t, huma.TypeString, custom.Type)
}