diff --git a/mashumaro/jsonschema/models.py b/mashumaro/jsonschema/models.py index b194ccc..c68e08a 100644 --- a/mashumaro/jsonschema/models.py +++ b/mashumaro/jsonschema/models.py @@ -8,6 +8,7 @@ from typing_extensions import TYPE_CHECKING, Self, TypeAlias from mashumaro.config import BaseConfig +from mashumaro.core.meta.helpers import iter_all_subclasses from mashumaro.helper import pass_through from mashumaro.jsonschema.dialects import DRAFT_2020_12, JSONSchemaDialect @@ -96,6 +97,17 @@ class JSONSchemaInstanceFormatExtension(JSONSchemaInstanceFormat): } +def _deserialize_json_schema_instance_format( + value: Any, +) -> JSONSchemaInstanceFormat: + for cls in iter_all_subclasses(JSONSchemaInstanceFormat): + try: + return cls(value) + except (ValueError, TypeError): + pass + raise ValueError(value) + + @dataclass(unsafe_hash=True) class JSONSchema(DataClassJSONMixin): # Common keywords @@ -103,9 +115,7 @@ class JSONSchema(DataClassJSONMixin): type: Optional[JSONSchemaInstanceType] = None enum: Optional[list[Any]] = None const: Optional[Any] = field(default_factory=lambda: MISSING) - format: Optional[ - Union[JSONSchemaStringFormat, JSONSchemaInstanceFormatExtension] - ] = None + format: Optional[JSONSchemaInstanceFormat] = None title: Optional[str] = None description: Optional[str] = None anyOf: Optional[List["JSONSchema"]] = None @@ -157,6 +167,9 @@ class Config(BaseConfig): int: pass_through, float: pass_through, Null: pass_through, + JSONSchemaInstanceFormat: { + "deserialize": _deserialize_json_schema_instance_format, + }, } def __pre_serialize__(self) -> Self: diff --git a/tests/test_jsonschema/test_json_schema_common.py b/tests/test_jsonschema/test_json_schema_common.py index 9b39e1c..9e2a8f5 100644 --- a/tests/test_jsonschema/test_json_schema_common.py +++ b/tests/test_jsonschema/test_json_schema_common.py @@ -1,4 +1,10 @@ +import pytest + from mashumaro.config import BaseConfig +from mashumaro.jsonschema.models import ( + JSONSchemaStringFormat, + _deserialize_json_schema_instance_format, +) from mashumaro.jsonschema.schema import Instance @@ -9,3 +15,12 @@ def test_instance_get_configs(): derived = instance.derive() assert derived.get_self_config() is instance.get_self_config() + + +def test_deserialize_json_schema_instance_format(): + assert ( + _deserialize_json_schema_instance_format("email") + is JSONSchemaStringFormat.EMAIL + ) + with pytest.raises(ValueError): + assert _deserialize_json_schema_instance_format("foobar") diff --git a/tests/test_jsonschema/test_jsonschema_generation.py b/tests/test_jsonschema/test_jsonschema_generation.py index 15b594f..9ffd4a0 100644 --- a/tests/test_jsonschema/test_jsonschema_generation.py +++ b/tests/test_jsonschema/test_jsonschema_generation.py @@ -65,14 +65,21 @@ from mashumaro.jsonschema.builder import JSONSchemaBuilder, build_json_schema from mashumaro.jsonschema.dialects import DRAFT_2020_12, OPEN_API_3_1 from mashumaro.jsonschema.models import ( + Context, JSONArraySchema, JSONObjectSchema, JSONSchema, + JSONSchemaInstanceFormat, JSONSchemaInstanceFormatExtension, JSONSchemaInstanceType, JSONSchemaStringFormat, ) -from mashumaro.jsonschema.schema import UTC_OFFSET_PATTERN, EmptyJSONSchema +from mashumaro.jsonschema.plugins import BasePlugin +from mashumaro.jsonschema.schema import ( + UTC_OFFSET_PATTERN, + EmptyJSONSchema, + Instance, +) from mashumaro.types import Discriminator, SerializationStrategy from tests.entities import ( CustomPath, @@ -1354,3 +1361,79 @@ class Main: additionalProperties=False, ) assert build_json_schema(Main) == schema + + +def test_jsonschema_with_custom_instance_format(): + class CustomJSONSchemaInstanceFormatPlugin(BasePlugin): + def get_schema( + self, + instance: Instance, + ctx: Context, + schema: Optional[JSONSchema] = None, + ) -> Optional[JSONSchema]: + for annotation in instance.annotations: + if isinstance(annotation, JSONSchemaInstanceFormat): + schema.format = annotation + return schema + + class Custom1InstanceFormat(JSONSchemaInstanceFormat): + CUSTOM1 = "custom1" + + class CustomInstanceFormatBase(JSONSchemaInstanceFormat): + pass + + class Custom2InstanceFormat(CustomInstanceFormatBase): + CUSTOM2 = "custom2" + + type1 = Annotated[str, Custom1InstanceFormat.CUSTOM1] + schema1 = build_json_schema( + type1, plugins=[CustomJSONSchemaInstanceFormatPlugin()] + ) + assert schema1.format is Custom1InstanceFormat.CUSTOM1 + assert schema1.to_dict()["format"] == "custom1" + + type2 = Annotated[int, Custom2InstanceFormat.CUSTOM2] + schema2 = build_json_schema( + type2, plugins=[CustomJSONSchemaInstanceFormatPlugin()] + ) + assert schema2.format is Custom2InstanceFormat.CUSTOM2 + assert schema2.to_dict()["format"] == "custom2" + + assert ( + JSONSchema.from_dict({"format": "custom1"}).format + is Custom1InstanceFormat.CUSTOM1 + ) + assert ( + JSONSchema.from_dict({"format": "custom2"}).format + is Custom2InstanceFormat.CUSTOM2 + ) + + @dataclass + class MyClass: + x: str + y: str + + class Config(BaseConfig): + json_schema = { + "properties": { + "x": {"type": "string", "format": "custom1"}, + "y": {"type": "string", "format": "custom2"}, + } + } + + schema3 = build_json_schema(MyClass) + assert schema3 == JSONObjectSchema( + title="MyClass", + properties={ + "x": JSONSchema( + type=JSONSchemaInstanceType.STRING, + format=Custom1InstanceFormat.CUSTOM1, + ), + "y": JSONSchema( + type=JSONSchemaInstanceType.STRING, + format=Custom2InstanceFormat.CUSTOM2, + ), + }, + required=["x", "y"], + additionalProperties=False, + )