27
27
from ._pydantic import pydantic_to_json_schema
28
28
from ._subgrammar import lexeme , subgrammar
29
29
30
+ JSONSchema = Union [bool , Mapping [str , Any ]]
30
31
31
32
def _to_compact_json (target : Any ) -> str :
32
33
# See 'Compact Encoding':
@@ -150,8 +151,8 @@ def _gen_json_string(
150
151
def _gen_json_object (
151
152
lm ,
152
153
* ,
153
- properties : Mapping [str , Any ],
154
- additional_properties : Union [ bool , Mapping [ str , Any ]] ,
154
+ properties : Mapping [str , JSONSchema ],
155
+ additional_properties : JSONSchema ,
155
156
required : Sequence [str ],
156
157
definitions : Mapping [str , Callable [[], GrammarFunction ]],
157
158
):
@@ -206,16 +207,12 @@ def _gen_list(lm, *, elements: tuple[GrammarFunction, ...], required: tuple[bool
206
207
def _gen_json_array (
207
208
lm ,
208
209
* ,
209
- prefix_items_schema : Sequence [Mapping [ str , Any ] ],
210
- item_schema : Union [ bool , Mapping [ str , Any ]] ,
210
+ prefix_items_schema : Sequence [JSONSchema ],
211
+ item_schema : JSONSchema ,
211
212
min_items : int ,
212
213
max_items : Optional [int ],
213
214
definitions : Mapping [str , Callable [[], GrammarFunction ]],
214
215
):
215
- if item_schema is True :
216
- # True means that anything goes
217
- item_schema = {}
218
-
219
216
if len (prefix_items_schema ) < min_items and item_schema is False :
220
217
raise ValueError (
221
218
f"PrefixItems has too few elements ({ len (prefix_items_schema )} ) to"
@@ -282,7 +279,7 @@ def _gen_json_array(
282
279
def _process_anyOf (
283
280
lm ,
284
281
* ,
285
- anyof_list : Sequence [Mapping [ str , Any ] ],
282
+ anyof_list : Sequence [JSONSchema ],
286
283
definitions : Mapping [str , Callable [[], GrammarFunction ]],
287
284
):
288
285
options = [_gen_json (json_schema = item , definitions = definitions ) for item in anyof_list ]
@@ -329,9 +326,14 @@ def _gen_json_any(lm):
329
326
@guidance (stateless = True )
330
327
def _gen_json (
331
328
lm ,
332
- json_schema : Mapping [ str , Any ] ,
329
+ json_schema : JSONSchema ,
333
330
definitions : Mapping [str , Callable [[], GrammarFunction ]],
334
331
):
332
+ if json_schema is True :
333
+ json_schema = {}
334
+ elif json_schema is False :
335
+ raise ValueError ("No valid JSON can be generated from a schema of `False`" )
336
+
335
337
validate_json_node_keys (json_schema )
336
338
337
339
if Keyword .ANYOF in json_schema :
@@ -403,7 +405,7 @@ def json(
403
405
* ,
404
406
schema : Union [
405
407
None ,
406
- Mapping [ str , Any ] ,
408
+ JSONSchema ,
407
409
Type ["pydantic.BaseModel" ],
408
410
"pydantic.TypeAdapter" ,
409
411
] = None ,
@@ -457,20 +459,25 @@ def json(
457
459
If True, the generated JSON will be forced to be compact (no whitespace).
458
460
If False, output will be whitespace-flexible (i.e. decided by the model).
459
461
"""
460
- if isinstance (schema , Mapping ):
462
+ if schema is None :
463
+ # Default schema is empty, "anything goes" schema
464
+ # TODO: consider default being `{"type": "object"}`
465
+ schema = {}
466
+ elif isinstance (schema , (Mapping , bool )):
461
467
# Raises jsonschema.exceptions.SchemaError or ValueError
462
468
# if schema is not valid
463
469
jsonschema .validators .Draft202012Validator .check_schema (schema )
464
- elif schema is None :
465
- schema = {}
466
- else :
470
+ elif isinstance (schema , pydantic .TypeAdapter ) or (isinstance (schema , type ) and issubclass (schema , pydantic .BaseModel )):
467
471
schema = pydantic_to_json_schema (schema )
472
+ else :
473
+ raise TypeError (f"Unsupported schema type: { type (schema )} " )
468
474
469
475
definitions : Mapping [str , Callable [[], GrammarFunction ]] = {}
470
- for dk in DEFS_KEYS :
471
- if dk in schema :
472
- assert len (definitions ) == 0 , "Found duplicate definitions"
473
- definitions = _build_definitions (schema [dk ])
476
+ if isinstance (schema , Mapping ):
477
+ for dk in DEFS_KEYS :
478
+ if dk in schema :
479
+ assert len (definitions ) == 0 , "Found duplicate definitions"
480
+ definitions = _build_definitions (schema [dk ])
474
481
475
482
return lm + with_temperature (
476
483
subgrammar (
@@ -488,11 +495,11 @@ def json(
488
495
489
496
490
497
def _build_definitions (
491
- raw_definitions : Mapping [str , Any ]
498
+ raw_definitions : Mapping [str , JSONSchema ]
492
499
) -> Mapping [str , Callable [[], GrammarFunction ]]:
493
500
definitions : Dict [str , Callable [[], GrammarFunction ]] = {}
494
501
495
- def build_definition (json_schema : Mapping [ str , Any ] ) -> Callable [[], GrammarFunction ]:
502
+ def build_definition (json_schema : JSONSchema ) -> Callable [[], GrammarFunction ]:
496
503
@guidance (stateless = True , dedent = False , cache = True )
497
504
def closure (lm ):
498
505
return lm + _gen_json (json_schema = json_schema , definitions = definitions )
0 commit comments