Skip to content

Commit

Permalink
Pass only relevant validate kwargs to Optional(default=...)
Browse files Browse the repository at this point in the history
  • Loading branch information
gschaffner committed Jun 28, 2022
1 parent 09c00ed commit d3a93dc
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 17 deletions.
31 changes: 24 additions & 7 deletions schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
obtained from config-files, forms, external services or command-line
parsing, converted from JSON/YAML (or something else) to Python data-types."""

import inspect
import re
from inspect import signature

try:
from contextlib import ExitStack
Expand Down Expand Up @@ -273,11 +273,20 @@ def _priority(s):
return COMPARABLE


def _invoke_with_optional_kwargs(f, **kwargs):
s = inspect.signature(f)
if len(s.parameters) == 0:
def _invoke_with_relevant_kwargs(f, **kwargs):
try:
sig = signature(f)
except ValueError as e:
if e.args[0].startswith("no signature found for builtin type"):
builtin = True
else:
raise
else:
builtin = False
if builtin:
return f()
return f(**kwargs)
relevant_param_names = set(sig.parameters.keys()) & set(kwargs.keys())
return f(**{name: kwargs[name] for name in relevant_param_names})


class Schema(object):
Expand Down Expand Up @@ -428,7 +437,11 @@ def validate(self, data, **kwargs):
# Apply default-having optionals that haven't been used:
defaults = set(k for k in s if isinstance(k, Optional) and hasattr(k, "default")) - coverage
for default in defaults:
new[default.key] = _invoke_with_optional_kwargs(default.default, **kwargs) if callable(default.default) else default.default
new[default.key] = (
_invoke_with_relevant_kwargs(default.default, **kwargs)
if callable(default.default)
else default.default
)

return new
if flavor == TYPE:
Expand Down Expand Up @@ -659,7 +672,11 @@ def _get_key_name(key):
sub_schema, is_main_schema=False, description=_get_key_description(key)
)
if isinstance(key, Optional) and hasattr(key, "default"):
expanded_schema[key_name]["default"] = _to_json_type(_invoke_with_optional_kwargs(key.default, **kwargs) if callable(key.default) else key.default)
expanded_schema[key_name]["default"] = _to_json_type(
_invoke_with_relevant_kwargs(key.default, **kwargs)
if callable(key.default)
else key.default
)
elif isinstance(key_name, Or):
# JSON schema does not support having a key named one name or another, so we just add both options
# This is less strict because we cannot enforce that one or the other is required
Expand Down
13 changes: 3 additions & 10 deletions test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ def convert(data, increment):
return data + increment
return data

s = {"k": int, "d": {Optional("k", default=lambda **kw: convert(2, kw['increment'])): int, "l": [{"l": [int]}]}}
s = {"k": int, "d": {Optional("k", default=lambda increment: convert(2, increment)): int, "l": [{"l": [int]}]}}
v = {"k": 1, "d": {"l": [{"l": [3, 4, 5]}]}}
d = Schema(s).validate(v, increment=1)
assert d["k"] == 1 and d["d"]["k"] == 3 and d["d"]["l"][0]["l"] == [3, 4, 5]
Expand Down Expand Up @@ -783,12 +783,7 @@ class MyOptional(Optional):
"""
@property
def default(self):

def wrapper(**kwargs):
if 'increment' in kwargs:
return convert(self._default, kwargs['increment'])
return self._default
return wrapper
return lambda increment: convert(self._default, increment)

@default.setter
def default(self, value):
Expand Down Expand Up @@ -1125,9 +1120,7 @@ def default_func():


def test_json_schema_default_is_callable_with_args_passed_from_json_schema():
def default_func(**kwargs):
return 'Hello, ' + kwargs['name']
s = Schema({Optional("test", default=default_func): str})
s = Schema({Optional("test", default=lambda name: "Hello, " + name): str})
assert s.json_schema("my-id", name='World!') == {
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "my-id",
Expand Down

0 comments on commit d3a93dc

Please sign in to comment.