Skip to content

Commit 35591d8

Browse files
authored
Support boolean JSON schemas (#1015)
Previously, `True`/`False` schemas were supported only when nested in certain places, such as `items` and `additionalProperties`. This change expands coverage to handle top-level boolean schemas as well. Note that a `False` schema now raises a `ValueError`, simply because there is nothing that can be generated in this case. This behavior may need to be revisited (see #1018).
1 parent 418fc03 commit 35591d8

File tree

2 files changed

+61
-21
lines changed

2 files changed

+61
-21
lines changed

guidance/library/_json.py

+28-21
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from ._pydantic import pydantic_to_json_schema
2828
from ._subgrammar import lexeme, subgrammar
2929

30+
JSONSchema = Union[bool, Mapping[str, Any]]
3031

3132
def _to_compact_json(target: Any) -> str:
3233
# See 'Compact Encoding':
@@ -150,8 +151,8 @@ def _gen_json_string(
150151
def _gen_json_object(
151152
lm,
152153
*,
153-
properties: Mapping[str, Any],
154-
additional_properties: Union[bool, Mapping[str, Any]],
154+
properties: Mapping[str, JSONSchema],
155+
additional_properties: JSONSchema,
155156
required: Sequence[str],
156157
definitions: Mapping[str, Callable[[], GrammarFunction]],
157158
):
@@ -206,16 +207,12 @@ def _gen_list(lm, *, elements: tuple[GrammarFunction, ...], required: tuple[bool
206207
def _gen_json_array(
207208
lm,
208209
*,
209-
prefix_items_schema: Sequence[Mapping[str, Any]],
210-
item_schema: Union[bool, Mapping[str, Any]],
210+
prefix_items_schema: Sequence[JSONSchema],
211+
item_schema: JSONSchema,
211212
min_items: int,
212213
max_items: Optional[int],
213214
definitions: Mapping[str, Callable[[], GrammarFunction]],
214215
):
215-
if item_schema is True:
216-
# True means that anything goes
217-
item_schema = {}
218-
219216
if len(prefix_items_schema) < min_items and item_schema is False:
220217
raise ValueError(
221218
f"PrefixItems has too few elements ({len(prefix_items_schema)}) to"
@@ -282,7 +279,7 @@ def _gen_json_array(
282279
def _process_anyOf(
283280
lm,
284281
*,
285-
anyof_list: Sequence[Mapping[str, Any]],
282+
anyof_list: Sequence[JSONSchema],
286283
definitions: Mapping[str, Callable[[], GrammarFunction]],
287284
):
288285
options = [_gen_json(json_schema=item, definitions=definitions) for item in anyof_list]
@@ -329,9 +326,14 @@ def _gen_json_any(lm):
329326
@guidance(stateless=True)
330327
def _gen_json(
331328
lm,
332-
json_schema: Mapping[str, Any],
329+
json_schema: JSONSchema,
333330
definitions: Mapping[str, Callable[[], GrammarFunction]],
334331
):
332+
if json_schema is True:
333+
json_schema = {}
334+
elif json_schema is False:
335+
raise ValueError("No valid JSON can be generated from a schema of `False`")
336+
335337
validate_json_node_keys(json_schema)
336338

337339
if Keyword.ANYOF in json_schema:
@@ -403,7 +405,7 @@ def json(
403405
*,
404406
schema: Union[
405407
None,
406-
Mapping[str, Any],
408+
JSONSchema,
407409
Type["pydantic.BaseModel"],
408410
"pydantic.TypeAdapter",
409411
] = None,
@@ -457,20 +459,25 @@ def json(
457459
If True, the generated JSON will be forced to be compact (no whitespace).
458460
If False, output will be whitespace-flexible (i.e. decided by the model).
459461
"""
460-
if isinstance(schema, Mapping):
462+
if schema is None:
463+
# Default schema is empty, "anything goes" schema
464+
# TODO: consider default being `{"type": "object"}`
465+
schema = {}
466+
elif isinstance(schema, (Mapping, bool)):
461467
# Raises jsonschema.exceptions.SchemaError or ValueError
462468
# if schema is not valid
463469
jsonschema.validators.Draft202012Validator.check_schema(schema)
464-
elif schema is None:
465-
schema = {}
466-
else:
470+
elif isinstance(schema, pydantic.TypeAdapter) or (isinstance(schema, type) and issubclass(schema, pydantic.BaseModel)):
467471
schema = pydantic_to_json_schema(schema)
472+
else:
473+
raise TypeError(f"Unsupported schema type: {type(schema)}")
468474

469475
definitions: Mapping[str, Callable[[], GrammarFunction]] = {}
470-
for dk in DEFS_KEYS:
471-
if dk in schema:
472-
assert len(definitions) == 0, "Found duplicate definitions"
473-
definitions = _build_definitions(schema[dk])
476+
if isinstance(schema, Mapping):
477+
for dk in DEFS_KEYS:
478+
if dk in schema:
479+
assert len(definitions) == 0, "Found duplicate definitions"
480+
definitions = _build_definitions(schema[dk])
474481

475482
return lm + with_temperature(
476483
subgrammar(
@@ -488,11 +495,11 @@ def json(
488495

489496

490497
def _build_definitions(
491-
raw_definitions: Mapping[str, Any]
498+
raw_definitions: Mapping[str, JSONSchema]
492499
) -> Mapping[str, Callable[[], GrammarFunction]]:
493500
definitions: Dict[str, Callable[[], GrammarFunction]] = {}
494501

495-
def build_definition(json_schema: Mapping[str, Any]) -> Callable[[], GrammarFunction]:
502+
def build_definition(json_schema: JSONSchema) -> Callable[[], GrammarFunction]:
496503
@guidance(stateless=True, dedent=False, cache=True)
497504
def closure(lm):
498505
return lm + _gen_json(json_schema=json_schema, definitions=definitions)

tests/unit/library/test_json.py

+33
Original file line numberDiff line numberDiff line change
@@ -2218,3 +2218,36 @@ def test_all_required_properties_doesnt_blow_up(self, num_properties):
22182218
HITS_MAGIC_NUMBER = 1
22192219
expected_hits = 0
22202220
assert cache_info.hits <= expected_hits + HITS_MAGIC_NUMBER
2221+
2222+
class TestBooleanSchema:
2223+
@pytest.mark.parametrize(
2224+
"target_obj",
2225+
[
2226+
123,
2227+
"hello",
2228+
[1, 2, 3],
2229+
{"a": 1},
2230+
None,
2231+
[{"a": 1}],
2232+
{"a": [1, 2, 3]},
2233+
{"a": {"b": 1}},
2234+
False,
2235+
True
2236+
],
2237+
)
2238+
def test_true_schema(self, target_obj):
2239+
# should be the same as an empty schema
2240+
schema_obj = True
2241+
generate_and_check(target_obj, schema_obj)
2242+
2243+
@pytest.mark.parametrize(
2244+
"schema_obj",
2245+
[
2246+
False,
2247+
{"type": "object", "properties": {"a": False}, "required": ["a"]},
2248+
]
2249+
)
2250+
def test_false_schema(self, schema_obj):
2251+
with pytest.raises(ValueError) as ve:
2252+
gen_json(schema=schema_obj)
2253+
assert ve.value.args[0] == "No valid JSON can be generated from a schema of `False`"

0 commit comments

Comments
 (0)