Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC recursive schemas #61

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 96 additions & 63 deletions src/hypothesis_jsonschema/_canonicalise.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import math
import re
from copy import deepcopy
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
from typing import Any, Dict, List, NoReturn, Optional, Set, Tuple, Union
from urllib.parse import urljoin

import jsonschema
from hypothesis.errors import InvalidArgument
Expand Down Expand Up @@ -68,6 +69,13 @@ def next_down(val: float) -> float:
return out


class LocalResolver(jsonschema.RefResolver):
    """A ``jsonschema.RefResolver`` that never performs network I/O.

    Any attempt to dereference a remote ``$ref`` raises
    ``HypothesisRefResolutionError`` instead of fetching the URI, so only
    references resolvable within the local schema document are supported.
    """

    def resolve_remote(self, uri: str) -> NoReturn:
        # Deliberately unsupported: fetching remote schemas would make test
        # generation nondeterministic and network-dependent.
        raise HypothesisRefResolutionError(
            f"hypothesis-jsonschema does not fetch remote references (uri={uri!r})"
        )


def _get_validator_class(schema: Schema) -> JSONSchemaValidator:
try:
validator = jsonschema.validators.validator_for(schema)
Expand All @@ -78,9 +86,9 @@ def _get_validator_class(schema: Schema) -> JSONSchemaValidator:
return validator


def make_validator(schema: Schema, resolver: LocalResolver) -> JSONSchemaValidator:
    """Build a validator instance for *schema*.

    The *resolver* is threaded through so that local ``$ref`` references can
    be dereferenced during validation; a ``LocalResolver`` raises on any
    remote reference rather than fetching it over the network.
    """
    # NOTE(review): this span was a diff hunk containing both the old and the
    # new version of the function; this is the post-change (new) version.
    validator = _get_validator_class(schema)
    return validator(schema, resolver=resolver)


class HypothesisRefResolutionError(jsonschema.exceptions.RefResolutionError):
Expand Down Expand Up @@ -202,7 +210,9 @@ def get_integer_bounds(schema: Schema) -> Tuple[Optional[int], Optional[int]]:
return lower, upper


def canonicalish(schema: JSONType) -> Dict[str, Any]:
def canonicalish(
schema: JSONType, resolver: Optional[LocalResolver] = None
) -> Dict[str, Any]:
"""Convert a schema into a more-canonical form.

This is obviously incomplete, but improves best-effort recognition of
Expand All @@ -224,12 +234,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
"but expected a dict."
)

if resolver is None:
resolver = LocalResolver.from_schema(deepcopy(schema))

if "const" in schema:
if not make_validator(schema).is_valid(schema["const"]):
if not make_validator(schema, resolver=resolver).is_valid(schema["const"]):
return FALSEY
return {"const": schema["const"]}
if "enum" in schema:
validator = make_validator(schema)
validator = make_validator(schema, resolver=resolver)
enum_ = sorted(
(v for v in schema["enum"] if validator.is_valid(v)), key=sort_key
)
Expand All @@ -253,15 +266,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
# Recurse into the value of each keyword with a schema (or list of them) as a value
for key in SCHEMA_KEYS:
if isinstance(schema.get(key), list):
schema[key] = [canonicalish(v) for v in schema[key]]
schema[key] = [canonicalish(v, resolver=resolver) for v in schema[key]]
elif isinstance(schema.get(key), (bool, dict)):
schema[key] = canonicalish(schema[key])
schema[key] = canonicalish(schema[key], resolver=resolver)
else:
assert key not in schema, (key, schema[key])
for key in SCHEMA_OBJECT_KEYS:
if key in schema:
schema[key] = {
k: v if isinstance(v, list) else canonicalish(v)
k: v if isinstance(v, list) else canonicalish(v, resolver=resolver)
for k, v in schema[key].items()
}

Expand Down Expand Up @@ -307,7 +320,9 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:

if "array" in type_ and "contains" in schema:
if isinstance(schema.get("items"), dict):
contains_items = merged([schema["contains"], schema["items"]])
contains_items = merged(
[schema["contains"], schema["items"]], resolver=resolver
)
if contains_items is not None:
schema["contains"] = contains_items

Expand Down Expand Up @@ -432,7 +447,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
type_.remove("object")
else:
propnames = schema.get("propertyNames", {})
validator = make_validator(propnames)
validator = make_validator(propnames, resolver=resolver)
if not all(validator.is_valid(name) for name in schema["required"]):
type_.remove("object")

Expand Down Expand Up @@ -461,9 +476,9 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
type_.remove(t)
if t not in ("integer", "number"):
not_["type"].remove(t)
not_ = canonicalish(not_)
not_ = canonicalish(not_, resolver=resolver)

m = merged([not_, {**schema, "type": type_}])
m = merged([not_, {**schema, "type": type_}], resolver=resolver)
if m is not None:
not_ = m
if not_ != FALSEY:
Expand Down Expand Up @@ -525,7 +540,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
else:
tmp = schema.copy()
ao = tmp.pop("allOf")
out = merged([tmp] + ao)
out = merged([tmp] + ao, resolver=resolver)
if isinstance(out, dict): # pragma: no branch
schema = out
# TODO: this assertion is solely because mypy 0.750 doesn't know
Expand All @@ -537,7 +552,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
one_of = sorted(one_of, key=encode_canonical_json)
one_of = [s for s in one_of if s != FALSEY]
if len(one_of) == 1:
m = merged([schema, one_of[0]])
m = merged([schema, one_of[0]], resolver=resolver)
if m is not None: # pragma: no branch
return m
if (not one_of) or one_of.count(TRUTHY) > 1:
Expand All @@ -552,23 +567,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]:
FALSEY = canonicalish(False)


class LocalResolver(jsonschema.RefResolver):
    """A ``jsonschema.RefResolver`` that never performs network I/O.

    Any attempt to dereference a remote ``$ref`` raises
    ``HypothesisRefResolutionError`` instead of fetching the URI, so only
    references resolvable within the local schema document are supported.
    """

    def resolve_remote(self, uri: str) -> NoReturn:
        # Deliberately unsupported: fetching remote schemas would make test
        # generation nondeterministic and network-dependent.
        raise HypothesisRefResolutionError(
            f"hypothesis-jsonschema does not fetch remote references (uri={uri!r})"
        )


def resolve_all_refs(
schema: Union[bool, Schema], *, resolver: LocalResolver = None
schema: Union[bool, Schema],
*,
resolver: LocalResolver = None,
seen_map: Dict[str, Set[str]] = None,
) -> Schema:
"""
Resolve all references in the given schema.

This handles nested definitions, but not recursive definitions.
The latter require special handling to convert to strategies and are much
less common, so we just ignore them (and error out) for now.
"""
"""Resolve all non-recursive references in the given schema."""
if seen_map is None:
seen_map = {}
if isinstance(schema, bool):
return canonicalish(schema)
assert isinstance(schema, dict), schema
Expand All @@ -579,43 +586,61 @@ def resolve_all_refs(
f"resolver={resolver} (type {type(resolver).__name__}) is not a RefResolver"
)

def is_recursive(reference: str) -> bool:
full_ref = urljoin(resolver.base_uri, reference) # type: ignore
return reference == "#" or reference in resolver._scopes_stack or full_ref in resolver._scopes_stack # type: ignore

# To avoid infinite recursion, we skip all recursive definitions, and such references will be processed later
# A definition is recursive if it contains a reference to itself or one of its ancestors.
if "$ref" in schema:
s = dict(schema)
ref = s.pop("$ref")
with resolver.resolving(ref) as got:
if s == {}:
return resolve_all_refs(got, resolver=resolver)
m = merged([s, got])
if m is None: # pragma: no cover
msg = f"$ref:{ref!r} had incompatible base schema {s!r}"
raise HypothesisRefResolutionError(msg)
return resolve_all_refs(m, resolver=resolver)
assert "$ref" not in schema
path = "-".join(resolver._scopes_stack)
seen_paths = seen_map.setdefault(path, set())
if schema["$ref"] not in seen_paths and not is_recursive(schema["$ref"]): # type: ignore
seen_paths.add(schema["$ref"]) # type: ignore
s = dict(schema)
ref = s.pop("$ref")
with resolver.resolving(ref) as got:
if s == {}:
return resolve_all_refs(got, resolver=resolver, seen_map=seen_map)
m = merged([s, got])
if m is None: # pragma: no cover
msg = f"$ref:{ref!r} had incompatible base schema {s!r}"
raise HypothesisRefResolutionError(msg)

return resolve_all_refs(m, resolver=resolver, seen_map=seen_map)

for key in SCHEMA_KEYS:
val = schema.get(key, False)
if isinstance(val, list):
schema[key] = [
resolve_all_refs(v, resolver=resolver) if isinstance(v, dict) else v
resolve_all_refs(deepcopy(v), resolver=resolver, seen_map=seen_map)
if isinstance(v, dict)
else v
for v in val
]
elif isinstance(val, dict):
schema[key] = resolve_all_refs(val, resolver=resolver)
schema[key] = resolve_all_refs(
deepcopy(val), resolver=resolver, seen_map=seen_map
)
else:
assert isinstance(val, bool)
for key in SCHEMA_OBJECT_KEYS: # values are keys-to-schema-dicts, not schemas
if key in schema:
subschema = schema[key]
assert isinstance(subschema, dict)
schema[key] = {
k: resolve_all_refs(v, resolver=resolver) if isinstance(v, dict) else v
k: resolve_all_refs(deepcopy(v), resolver=resolver, seen_map=seen_map)
if isinstance(v, dict)
else v
for k, v in subschema.items()
}
assert isinstance(schema, dict)
return schema


def merged(schemas: List[Any]) -> Optional[Schema]:
def merged(
schemas: List[Any], resolver: Optional[LocalResolver] = None
) -> Optional[Schema]:
"""Merge *n* schemas into a single schema, or None if result is invalid.

Takes the logical intersection, so any object that validates against the returned
Expand All @@ -628,7 +653,9 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
It's currently also used for keys that could be merged but aren't yet.
"""
assert schemas, "internal error: must pass at least one schema to merge"
schemas = sorted((canonicalish(s) for s in schemas), key=upper_bound_instances)
schemas = sorted(
(canonicalish(s, resolver=resolver) for s in schemas), key=upper_bound_instances
)
if any(s == FALSEY for s in schemas):
return FALSEY
out = schemas[0]
Expand All @@ -637,11 +664,11 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
continue
# If we have a const or enum, this is fairly easy by filtering:
if "const" in out:
if make_validator(s).is_valid(out["const"]):
if make_validator(s, resolver=resolver).is_valid(out["const"]):
continue
return FALSEY
if "enum" in out:
validator = make_validator(s)
validator = make_validator(s, resolver=resolver)
enum_ = [v for v in out["enum"] if validator.is_valid(v)]
if not enum_:
return FALSEY
Expand Down Expand Up @@ -692,36 +719,41 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
else:
out_combined = merged(
[s for p, s in out_pat.items() if re.search(p, prop_name)]
or [out_add]
or [out_add],
resolver=resolver,
)
if prop_name in s_props:
s_combined = s_props[prop_name]
else:
s_combined = merged(
[s for p, s in s_pat.items() if re.search(p, prop_name)]
or [s_add]
or [s_add],
resolver=resolver,
)
if out_combined is None or s_combined is None: # pragma: no cover
# Note that this can only be the case if we were actually going to
# use the schema which we attempted to merge, i.e. prop_name was
# not in the schema and there were unmergable pattern schemas.
return None
m = merged([out_combined, s_combined])
m = merged([out_combined, s_combined], resolver=resolver)
if m is None:
return None
out_props[prop_name] = m
# With all the property names done, it's time to handle the patterns. This is
# simpler as we merge with either an identical pattern, or additionalProperties.
if out_pat or s_pat:
for pattern in set(out_pat) | set(s_pat):
m = merged([out_pat.get(pattern, out_add), s_pat.get(pattern, s_add)])
m = merged(
[out_pat.get(pattern, out_add), s_pat.get(pattern, s_add)],
resolver=resolver,
)
if m is None: # pragma: no cover
return None
out_pat[pattern] = m
out["patternProperties"] = out_pat
# Finally, we merge together the additionalProperties schemas.
if out_add or s_add:
m = merged([out_add, s_add])
m = merged([out_add, s_add], resolver=resolver)
if m is None: # pragma: no cover
return None
out["additionalProperties"] = m
Expand Down Expand Up @@ -755,7 +787,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
return None
if "contains" in out and "contains" in s and out["contains"] != s["contains"]:
# If one `contains` schema is a subset of the other, we can discard it.
m = merged([out["contains"], s["contains"]])
m = merged([out["contains"], s["contains"]], resolver=resolver)
if m == out["contains"] or m == s["contains"]:
out["contains"] = m
s.pop("contains")
Expand Down Expand Up @@ -785,7 +817,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
v = {"required": v}
elif isinstance(sval, list):
sval = {"required": sval}
m = merged([v, sval])
m = merged([v, sval], resolver=resolver)
if m is None:
return None
odeps[k] = m
Expand All @@ -799,26 +831,27 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
[
out.get("additionalItems", TRUTHY),
s.get("additionalItems", TRUTHY),
]
],
resolver=resolver,
)
for a, b in itertools.zip_longest(oitems, sitems):
if a is None:
a = out.get("additionalItems", TRUTHY)
elif b is None:
b = s.get("additionalItems", TRUTHY)
out["items"].append(merged([a, b]))
out["items"].append(merged([a, b], resolver=resolver))
elif isinstance(oitems, list):
out["items"] = [merged([x, sitems]) for x in oitems]
out["items"] = [merged([x, sitems], resolver=resolver) for x in oitems]
out["additionalItems"] = merged(
[out.get("additionalItems", TRUTHY), sitems]
[out.get("additionalItems", TRUTHY), sitems], resolver=resolver
)
elif isinstance(sitems, list):
out["items"] = [merged([x, oitems]) for x in sitems]
out["items"] = [merged([x, oitems], resolver=resolver) for x in sitems]
out["additionalItems"] = merged(
[s.get("additionalItems", TRUTHY), oitems]
[s.get("additionalItems", TRUTHY), oitems], resolver=resolver
)
else:
out["items"] = merged([oitems, sitems])
out["items"] = merged([oitems, sitems], resolver=resolver)
if out["items"] is None:
return None
if isinstance(out["items"], list) and None in out["items"]:
Expand All @@ -842,7 +875,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]:
# If non-validation keys like `title` or `description` don't match,
# that doesn't really matter and we'll just go with first we saw.
return None
out = canonicalish(out)
out = canonicalish(out, resolver=resolver)
if out == FALSEY:
return FALSEY
assert isinstance(out, dict)
Expand Down
Loading