Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache computation of JSON nodes to improve performance #995

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
162 changes: 94 additions & 68 deletions guidance/library/_json.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from json import dumps as json_dumps
from enum import Enum
from frozendict import frozendict, deepfreeze
from functools import cache
from typing import (
Any,
Callable,
Dict,
Mapping,
Optional,
Sequence,
Union,
Type,
TYPE_CHECKING,
Expand All @@ -29,6 +29,7 @@
from ._subgrammar import lexeme, subgrammar

JSONSchema = Union[bool, Mapping[str, Any]]
FrozenJSONSchema = Union[bool, frozendict[str, Any]]

def _to_compact_json(target: Any) -> str:
# See 'Compact Encoding':
Expand Down Expand Up @@ -392,12 +393,12 @@ def validate_json_node_keys(node: Mapping[str, Any]):
)


@guidance(stateless=True)
@guidance(stateless=True, cache=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If adding cache here gives a speed up, will this result in every generated integer being the same?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll certainly double check before marking this as non-WIP, but I believe it's just caching the return value of this function, which is a GrammarFunction, not the actual string value. In other words, I believe the answer to your question is no

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related question, assuming that the return value is OK: does this get cross-referenced properly when the grammar is generated? Looking at some of the serialisation code, I expect that it does, but would be good to make sure that it doesn't get expanded inline multiple times.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cache=True will ensure that subsequent calls with the same arguments will always return the exact same object (i.e. id(o1) == id(o2)).

The serialization code maintains its own cache for handling "pointers" (just integer indices) to objects in the serialized grammar, as you can see here. GrammarFunctions are hashable and just hash to the id of the object, so this code should correctly respect the object caching we do in the decorator.

def _gen_json_int(lm):
return lm + lexeme(r"-?(?:0|[1-9][0-9]*)", contextual=True)


@guidance(stateless=True)
@guidance(stateless=True, cache=True)
def _gen_json_number(lm):
return lm + select([
_gen_json_int(),
Expand All @@ -406,7 +407,7 @@ def _gen_json_number(lm):
])


@guidance(stateless=True)
@guidance(stateless=True, cache=True)
def _gen_json_string(
lm,
min_length: int = 0,
Expand Down Expand Up @@ -438,14 +439,14 @@ def _gen_json_string(
return lm + lexeme(regex, contextual=True, json_string=True)


@guidance(stateless=True)
@guidance(stateless=True, cache=True)
def _gen_json_object(
lm,
*,
properties: Mapping[str, JSONSchema],
additional_properties: JSONSchema,
required: Sequence[str],
definitions: Mapping[str, Callable[[], GrammarFunction]],
properties: frozendict[str, FrozenJSONSchema],
additional_properties: FrozenJSONSchema,
required: frozenset[str],
definitions: frozendict[str, Callable[[], GrammarFunction]],
):
# "required" keys will be validated against "properties" if they're present, otherwise against "additionalProperties".
# If "additionalProperties" is False, then required keys must be in "properties".
Expand Down Expand Up @@ -502,15 +503,15 @@ def _gen_list(lm, *, elements: tuple[GrammarFunction, ...], required: tuple[bool
])


@guidance(stateless=True)
@guidance(stateless=True, cache=True)
def _gen_json_array(
lm,
*,
prefix_items_schema: Sequence[JSONSchema],
item_schema: JSONSchema,
prefix_items_schema: tuple[FrozenJSONSchema, ...],
item_schema: FrozenJSONSchema,
min_items: int,
max_items: Optional[int],
definitions: Mapping[str, Callable[[], GrammarFunction]],
definitions: frozendict[str, Callable[[], GrammarFunction]],
):
if len(prefix_items_schema) < min_items and item_schema is False:
raise ValueError(
Expand Down Expand Up @@ -574,62 +575,93 @@ def _gen_json_array(
return lm


@guidance(stateless=True)
@guidance(stateless=True, cache=True)
def _process_anyOf(
lm,
*,
anyof_list: Sequence[JSONSchema],
definitions: Mapping[str, Callable[[], GrammarFunction]],
anyof_list: tuple[FrozenJSONSchema, ...],
definitions: frozendict[str, Callable[[], GrammarFunction]],
):
options = [_gen_json(json_schema=item, definitions=definitions) for item in anyof_list]
return lm + select(options)

@guidance(stateless=True, cache=True)
def _process_allOf(
    lm,
    *,
    allof_list: tuple[FrozenJSONSchema, ...],
    definitions: frozendict[str, Callable[[], GrammarFunction]],
):
    """Generate grammar for a JSON-schema ``allOf``.

    Only single-schema allOf lists are supported: the sole schema is
    delegated to ``_gen_json``. Any other length raises ``ValueError``.
    Arguments are hashable (tuple/frozendict) so the guidance cache can
    memoize the returned grammar.
    """
    if len(allof_list) == 1:
        (sole_schema,) = allof_list
        return lm + _gen_json(sole_schema, definitions=definitions)
    raise ValueError("Only support allOf with exactly one item")

@guidance(stateless=True)
def _process_enum(lm, *, options: Sequence[Mapping[str, Any]]):
@guidance(stateless=True, cache=True)
def _process_oneOf(
    lm,
    *,
    oneof_list: tuple[FrozenJSONSchema, ...],
    definitions: frozendict[str, Callable[[], GrammarFunction]]
):
    """Generate grammar for a JSON-schema ``oneOf``.

    A single-schema oneOf is handled exactly; longer lists fall back to
    ``anyOf`` semantics (with a warning), which may accept values that a
    strict oneOf validator would reject.
    """
    if len(oneof_list) == 1:
        # Keyword argument for consistency with the other _process_* helpers.
        return lm + _gen_json(oneof_list[0], definitions=definitions)
    warnings.warn("oneOf not fully supported, falling back to anyOf. This may cause validation errors in some cases.")
    return lm + _process_anyOf(anyof_list=oneof_list, definitions=definitions)

@guidance(stateless=True, cache=True)
def _process_const(
    lm,
    *,
    value: Any,
):
    """Emit a JSON-schema ``const`` value as its exact compact-JSON text."""
    # TODO: can we support a whitespace-flexible version of this?
    literal = _to_compact_json(value)
    return lm + literal

@guidance(stateless=True, cache=True)
def _process_enum(lm, *, options: tuple[Any, ...]):
    """Generate grammar selecting one of the ``enum`` options.

    Each option is rendered as its compact-JSON text and offered to
    ``select``; options is a tuple so the call is hashable for caching.
    """
    # TODO: can we support a whitespace-flexible version of this?
    # Comprehension instead of a manual append loop (PERF401).
    all_opts = [_to_compact_json(opt) for opt in options]
    return lm + select(options=all_opts)


@guidance(stateless=True)
@guidance(stateless=True, cache=True)
def _gen_json_any(lm):
return lm + select(
[
_gen_json(json_schema={"type": "null"}, definitions={}),
_gen_json(json_schema={"type": "boolean"}, definitions={}),
_gen_json(json_schema={"type": "integer"}, definitions={}),
_gen_json(json_schema={"type": "number"}, definitions={}),
_gen_json(json_schema={"type": "string"}, definitions={}),
_gen_json(json_schema=frozendict({"type": "null"}), definitions=frozendict()),
_gen_json(json_schema=frozendict({"type": "boolean"}), definitions=frozendict()),
_gen_json(json_schema=frozendict({"type": "integer"}), definitions=frozendict()),
_gen_json(json_schema=frozendict({"type": "number"}), definitions=frozendict()),
_gen_json(json_schema=frozendict({"type": "string"}), definitions=frozendict()),
# Recursive cases
_gen_json(
json_schema={
json_schema=frozendict({
"type": "array",
"items": True,
},
definitions={},
}),
definitions=frozendict(),
),
_gen_json(
json_schema={
json_schema=frozendict({
"type": "object",
"additionalProperties": True,
},
definitions={},
}),
definitions=frozendict(),
),
]
)


@guidance(stateless=True)
@guidance(stateless=True, cache=True)
def _gen_json(
lm,
json_schema: JSONSchema,
definitions: Mapping[str, Callable[[], GrammarFunction]],
json_schema: FrozenJSONSchema,
definitions: frozendict[str, Callable[[], GrammarFunction]],
):
if json_schema is True:
json_schema = {}
json_schema = frozendict()
elif json_schema is False:
raise ValueError("No valid JSON can be generated from a schema of `False`")

Expand All @@ -642,34 +674,26 @@ def _gen_json(
return lm + _process_anyOf(anyof_list=json_schema[Keyword.ANYOF], definitions=definitions)

if Keyword.ALLOF in json_schema:
allof_list = json_schema[Keyword.ALLOF]
if len(allof_list) != 1:
raise ValueError("Only support allOf with exactly one item")
return lm + _gen_json(allof_list[0], definitions)
return lm + _process_allOf(allof_list=json_schema[Keyword.ALLOF], definitions=definitions)

if Keyword.ONEOF in json_schema:
oneof_list = json_schema[Keyword.ONEOF]
if len(oneof_list) == 1:
return lm + _gen_json(oneof_list[0], definitions)
warnings.warn("oneOf not fully supported, falling back to anyOf. This may cause validation errors in some cases.")
return lm + _process_anyOf(anyof_list=oneof_list, definitions=definitions)
return lm + _process_oneOf(oneof_list=json_schema[Keyword.ONEOF], definitions=definitions)

if Keyword.REF in json_schema:
return lm + _get_definition(reference=json_schema[Keyword.REF], definitions=definitions)

if Keyword.CONST in json_schema:
# TODO: can we support a whitespace-flexible version of this?
return lm + _to_compact_json(json_schema[Keyword.CONST])
return lm + _process_const(value=json_schema[Keyword.CONST])

if Keyword.ENUM in json_schema:
return lm + _process_enum(options=json_schema[Keyword.ENUM])

if Keyword.TYPE in json_schema:
target_types = cast(Union[str, Sequence[str]], json_schema[Keyword.TYPE])
target_types = cast(Union[str, tuple[str, ...]], json_schema[Keyword.TYPE])
if isinstance(target_types, str):
target_types = [target_types]
target_types = (target_types,)
else:
target_types = list(JSONType)
target_types = tuple(JSONType)

options: list[Union[str, GrammarFunction]] = []
option: Union[str, GrammarFunction]
Expand All @@ -691,17 +715,17 @@ def _gen_json(
)
elif target_type == JSONType.ARRAY:
option = _gen_json_array(
prefix_items_schema=json_schema.get(ArrayKeywords.PREFIX_ITEMS, []),
prefix_items_schema=json_schema.get(ArrayKeywords.PREFIX_ITEMS, ()),
item_schema=json_schema.get(ArrayKeywords.ITEMS, True),
min_items=json_schema.get(ArrayKeywords.MIN_ITEMS, 0),
max_items=json_schema.get(ArrayKeywords.MAX_ITEMS, None),
definitions=definitions,
)
elif target_type == JSONType.OBJECT:
option = _gen_json_object(
properties=json_schema.get(ObjectKeywords.PROPERTIES, {}),
properties=json_schema.get(ObjectKeywords.PROPERTIES, frozendict()),
additional_properties=json_schema.get(ObjectKeywords.ADDITIONAL_PROPERTIES, True),
required=json_schema.get(ObjectKeywords.REQUIRED, set()),
required=json_schema.get(ObjectKeywords.REQUIRED, frozenset()),
definitions=definitions,
)
else:
Expand Down Expand Up @@ -785,17 +809,22 @@ def json(
else:
raise TypeError(f"Unsupported schema type: {type(schema)}")

definitions: Mapping[str, Callable[[], GrammarFunction]] = {}
definitions: frozendict[str, Callable[[], GrammarFunction]] = frozendict()
frozen_schema: FrozenJSONSchema
if isinstance(schema, Mapping):
# Freeze the schema to make it immutable and hashable
frozen_schema = cast(frozendict, deepfreeze(schema))
for dk in DEFS_KEYS:
if dk in schema:
if dk in frozen_schema:
assert len(definitions) == 0, "Found duplicate definitions"
definitions = _build_definitions(schema[dk])
definitions = _build_definitions(frozen_schema[dk])
else:
frozen_schema = cast(bool, schema)

return lm + with_temperature(
subgrammar(
name,
body=_gen_json(json_schema=schema, definitions=definitions),
body=_gen_json(json_schema=frozen_schema, definitions=definitions),
skip_regex=(
None if compact
else r"[\x20\x0A\x0D\x09]+"
Expand All @@ -806,30 +835,27 @@ def json(
temperature=temperature,
)


@cache
def _build_definitions(
raw_definitions: Mapping[str, JSONSchema]
) -> Mapping[str, Callable[[], GrammarFunction]]:
definitions: Dict[str, Callable[[], GrammarFunction]] = {}
raw_definitions: frozendict[str, FrozenJSONSchema]
) -> frozendict[str, Callable[[], GrammarFunction]]:
definitions: frozendict[str, Callable[[], GrammarFunction]]

def build_definition(json_schema: JSONSchema) -> Callable[[], GrammarFunction]:
def build_definition(json_schema: FrozenJSONSchema) -> Callable[[], GrammarFunction]:
@guidance(stateless=True, dedent=False, cache=True)
def closure(lm):
return lm + _gen_json(json_schema=json_schema, definitions=definitions)

return closure

definitions = {ref: build_definition(schema) for ref, schema in raw_definitions.items()}
definitions = frozendict({ref: build_definition(schema) for ref, schema in raw_definitions.items()})
return definitions


@guidance(stateless=True)
def _get_definition(
lm,
*,
reference: str,
definitions: Mapping[str, Callable[[], GrammarFunction]],
):
definitions: frozendict[str, Callable[[], GrammarFunction]],
) -> GrammarFunction:
assert definitions is not None
target_definition = None
for dk in DEFS_KEYS:
Expand All @@ -839,4 +865,4 @@ def _get_definition(
target_definition = definitions[target_name]

assert target_definition is not None
return lm + target_definition()
return target_definition()
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

install_requires = [
"diskcache",
"frozendict",
"numpy",
"ordered_set",
"platformdirs",
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/library/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -1428,6 +1428,10 @@ def test_oneOf_compound(self, target_obj):
validate(instance=target_obj, schema=schema_obj)

# The actual check; we expect a warning here because oneOf is not fully supported
from guidance.library._json import _process_oneOf, _gen_json
# Reset the caches to ensure we get the warning
_gen_json.__wrapped__.cache_clear()
_process_oneOf.__wrapped__.cache_clear()
with pytest.warns() as record:
generate_and_check(target_obj, schema_obj)
assert len(record) == 1
Expand Down
Loading