Skip to content

Commit 30af14a

Browse files
committed
Create the json schema conversion methods
1 parent b8f8803 commit 30af14a

File tree

4 files changed

+660
-20
lines changed

4 files changed

+660
-20
lines changed

outlines/types/dsl.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import (
2323
Any,
2424
List,
25+
Literal,
2526
Optional as OptionalType,
2627
Union,
2728
get_args,
@@ -41,6 +42,11 @@
4142

4243
import outlines.types as types
4344
from outlines import grammars
45+
from outlines.types.json_schema_utils import (
46+
json_schema_dict_to_pydantic,
47+
json_schema_dict_to_typeddict,
48+
json_schema_dict_to_dataclass,
49+
)
4450
from outlines.types.utils import (
4551
get_schema_from_signature,
4652
is_int,
@@ -282,6 +288,9 @@ class JsonSchema(Term):
282288
genSON schema builder.
283289
284290
"""
291+
schema: str
292+
whitespace_pattern: OptionalType[str]
293+
285294
def __init__(
286295
self,
287296
schema: Union[
@@ -329,6 +338,109 @@ def __init__(
329338
def __post_init__(self):
330339
jsonschema.Draft7Validator.check_schema(json.loads(self.schema))
331340

341+
@classmethod
342+
def is_json_schema(cls, obj: Any) -> bool:
343+
"""Check if the object provided is a JSON schema type.
344+
345+
Parameters
346+
----------
347+
obj: Any
348+
The object to check
349+
350+
Returns
351+
-------
352+
bool
353+
True if the object is a JSON schema type, False otherwise
354+
355+
"""
356+
return (
357+
isinstance(obj, cls)
358+
or is_pydantic_model(obj)
359+
or is_typed_dict(obj)
360+
or is_dataclass(obj)
361+
or is_genson_schema_builder(obj)
362+
)
363+
364+
@classmethod
365+
def convert_to(
366+
cls,
367+
schema: Union[
368+
"JsonSchema",
369+
type[BaseModel],
370+
_TypedDictMeta,
371+
type,
372+
SchemaBuilder,
373+
],
374+
target_types: List[Literal[
375+
"str",
376+
"dict",
377+
"pydantic",
378+
"typeddict",
379+
"dataclass",
380+
"genson",
381+
]],
382+
) -> Union[str, dict, type[BaseModel], _TypedDictMeta, type, SchemaBuilder]:
383+
"""Convert a JSON schema type to a different JSON schema type.
384+
385+
If the schema provided is already of a type in the target_types, return
386+
it unchanged.
387+
388+
Parameters
389+
----------
390+
schema: Union[JsonSchema, type[BaseModel], _TypedDictMeta, type, SchemaBuilder]
391+
The schema to convert
392+
target_types: List[Literal["str", "dict", "pydantic", "typeddict", "dataclass", "genson"]]
393+
The target types to convert to
394+
395+
"""
396+
# If the schema provided is already of a type in the target_types,
397+
# just return it
398+
if isinstance(schema, cls):
399+
if "str" in target_types:
400+
return schema.schema
401+
elif "dict" in target_types:
402+
return json.loads(schema.schema)
403+
elif is_pydantic_model(schema) and "pydantic" in target_types:
404+
return schema
405+
elif is_typed_dict(schema) and "typeddict" in target_types:
406+
return schema
407+
elif is_dataclass(schema) and "dataclass" in target_types:
408+
return schema
409+
elif is_genson_schema_builder(schema) and "genson" in target_types:
410+
return schema
411+
412+
# Convert the schema to a JSON schema string/dict
413+
if isinstance(schema, cls):
414+
schema_str = schema.schema
415+
else:
416+
schema_str = cls(schema).schema
417+
schema_dict = json.loads(schema_str)
418+
419+
for target_type in target_types:
420+
try:
421+
# Convert the JSON schema string to the target type
422+
if target_type == "str":
423+
return schema_str
424+
elif target_type == "dict":
425+
return schema_dict
426+
elif target_type == "pydantic":
427+
return json_schema_dict_to_pydantic(schema_dict)
428+
elif target_type == "typeddict":
429+
return json_schema_dict_to_typeddict(schema_dict)
430+
elif target_type == "dataclass":
431+
return json_schema_dict_to_dataclass(schema_dict)
432+
# No conversion available for genson
433+
except Exception as e: # pragma: no cover
434+
warnings.warn(
435+
f"Cannot convert schema type {type(schema)} to {target_type}: {e}"
436+
)
437+
continue
438+
439+
raise ValueError(
440+
f"Cannot convert schema type {type(schema)} to any of the target "
441+
f"types {target_types}"
442+
)
443+
332444
def _display_node(self) -> str:
333445
return f"JsonSchema('{self.schema}')"
334446

Lines changed: 175 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,175 @@
1-
"""Utilities for handling JSON schema compatibility."""
1+
"""Convert JSON Schema dicts to Python types."""
2+
3+
import sys
4+
from dataclasses import dataclass, field
5+
from typing import Any, Dict, List, Literal, Optional
6+
7+
from pydantic import BaseModel, create_model
8+
9+
if sys.version_info >= (3, 12): # pragma: no cover
10+
from typing import _TypedDictMeta, TypedDict # type: ignore
11+
else: # pragma: no cover
12+
from typing_extensions import _TypedDictMeta, TypedDict # type: ignore
13+
14+
15+
def schema_type_to_python(
16+
schema: dict,
17+
caller_target_type: Literal["pydantic", "typeddict", "dataclass"]
18+
) -> Any:
19+
"""Get a Python type from a JSON Schema dict.
20+
21+
Parameters
22+
----------
23+
schema: dict
24+
The JSON Schema dict to convert to a Python type
25+
caller_target_type: Literal["pydantic", "typeddict", "dataclass"]
26+
The type of the caller
27+
28+
Returns
29+
-------
30+
Any
31+
The Python type
32+
33+
"""
34+
if "enum" in schema:
35+
values = schema["enum"]
36+
return Literal[tuple(values)]
37+
38+
t = schema.get("type")
39+
40+
if t == "string":
41+
return str
42+
elif t == "integer":
43+
return int
44+
elif t == "number":
45+
return float
46+
elif t == "boolean":
47+
return bool
48+
elif t == "array":
49+
items = schema.get("items", {})
50+
if items:
51+
item_type = schema_type_to_python(items, caller_target_type)
52+
else:
53+
item_type = Any
54+
return List[item_type] # type: ignore
55+
elif t == "object":
56+
name = schema.get("title")
57+
if caller_target_type == "pydantic":
58+
return json_schema_dict_to_pydantic(schema, name)
59+
elif caller_target_type == "typeddict":
60+
return json_schema_dict_to_typeddict(schema, name)
61+
elif caller_target_type == "dataclass":
62+
return json_schema_dict_to_dataclass(schema, name)
63+
64+
return Any
65+
66+
67+
def json_schema_dict_to_typeddict(
68+
schema: dict,
69+
name: Optional[str] = None
70+
) -> _TypedDictMeta:
71+
"""Convert a JSON Schema dict into a TypedDict class.
72+
73+
Parameters
74+
----------
75+
schema: dict
76+
The JSON Schema dict to convert to a TypedDict
77+
name: Optional[str]
78+
The name of the TypedDict
79+
80+
Returns
81+
-------
82+
_TypedDictMeta
83+
The TypedDict class
84+
85+
"""
86+
required = set(schema.get("required", []))
87+
properties = schema.get("properties", {})
88+
89+
annotations: Dict[str, Any] = {}
90+
91+
for property, details in properties.items():
92+
typ = schema_type_to_python(details, "typeddict")
93+
if property not in required:
94+
typ = Optional[typ]
95+
annotations[property] = typ
96+
97+
return TypedDict(name or "AnonymousTypedDict", annotations) # type: ignore
98+
99+
100+
def json_schema_dict_to_pydantic(
101+
schema: dict,
102+
name: Optional[str] = None
103+
) -> type[BaseModel]:
104+
"""Convert a JSON Schema dict into a Pydantic BaseModel class.
105+
106+
Parameters
107+
----------
108+
schema: dict
109+
The JSON Schema dict to convert to a Pydantic BaseModel
110+
name: Optional[str]
111+
The name of the Pydantic BaseModel
112+
113+
Returns
114+
-------
115+
type[BaseModel]
116+
The Pydantic BaseModel class
117+
118+
"""
119+
required = set(schema.get("required", []))
120+
properties = schema.get("properties", {})
121+
122+
field_definitions: Dict[str, Any] = {}
123+
124+
for property, details in properties.items():
125+
typ = schema_type_to_python(details, "pydantic")
126+
if property not in required:
127+
field_definitions[property] = (Optional[typ], None)
128+
else:
129+
field_definitions[property] = (typ, ...)
130+
131+
return create_model(name or "AnonymousPydanticModel", **field_definitions)
132+
133+
134+
def json_schema_dict_to_dataclass(
135+
schema: dict,
136+
name: Optional[str] = None
137+
) -> type:
138+
"""Convert a JSON Schema dict into a dataclass.
139+
140+
Parameters
141+
----------
142+
schema: dict
143+
The JSON Schema dict to convert to a dataclass
144+
name: Optional[str]
145+
The name of the dataclass
146+
147+
Returns
148+
-------
149+
type
150+
The dataclass
151+
152+
"""
153+
required = set(schema.get("required", []))
154+
properties = schema.get("properties", {})
155+
156+
annotations: Dict[str, Any] = {}
157+
defaults: Dict[str, Any] = {}
158+
159+
for property, details in properties.items():
160+
typ = schema_type_to_python(details, "dataclass")
161+
annotations[property] = typ
162+
163+
if property not in required:
164+
defaults[property] = None
165+
166+
class_dict = {
167+
'__annotations__': annotations,
168+
'__module__': __name__,
169+
}
170+
171+
for property, default_val in defaults.items():
172+
class_dict[property] = field(default=default_val)
173+
174+
cls = type(name or "AnonymousDataclass", (), class_dict)
175+
return dataclass(cls)

0 commit comments

Comments
 (0)