Skip to content

Commit

Permalink
Merge pull request #176 from fuhrysteve/nested_lambda
Browse files Browse the repository at this point in the history
🐛 Fix `AttributeError` when nesting schemas using `lambda: Schema()`
  • Loading branch information
fuhrysteve authored Mar 14, 2023
2 parents c3549be + 77db62e commit 7446d3d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
17 changes: 14 additions & 3 deletions marshmallow_jsonschema/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ def get_properties(self, obj) -> typing.Dict[str, typing.Dict[str, typing.Any]]:
if self.props_ordered:
fields_items_sequence = obj.fields.items()
else:
fields_items_sequence = sorted(obj.fields.items())
if callable(obj):
fields_items_sequence = sorted(obj().fields.items())
else:
fields_items_sequence = sorted(obj.fields.items())

for field_name, field in fields_items_sequence:
schema = self._get_schema_for_field(obj, field)
Expand All @@ -172,8 +175,11 @@ def get_properties(self, obj) -> typing.Dict[str, typing.Dict[str, typing.Any]]:
def get_required(self, obj) -> typing.Union[typing.List[str], _Missing]:
"""Fill out required field."""
required = []

for field_name, field in sorted(obj.fields.items()):
if callable(obj):
field_items_iterable = sorted(obj().fields.items())
else:
field_items_iterable = sorted(obj.fields.items())
for field_name, field in field_items_iterable:
if field.required:
required.append(field.data_key or field.name)

Expand Down Expand Up @@ -294,6 +300,11 @@ def _from_nested_schema(self, obj, field):
exclude = field.exclude
nested_cls = nested
nested_instance = nested(only=only, exclude=exclude, context=obj.context)
elif callable(nested):
nested_instance = nested()
nested_type = type(nested_instance)
name = nested_type.__name__
nested_cls = nested_type.__class__
else:
nested_cls = nested.__class__
name = nested_cls.__name__
Expand Down
27 changes: 27 additions & 0 deletions tests/test_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,3 +744,30 @@ class TestSchema(Schema):
assert (
len(data["definitions"]["TestSchema"]["properties"]["union_prop"]["anyOf"]) == 3
)


def test_dumping_recursive_schema():
"""
this reproduces issue https://github.com/fuhrysteve/marshmallow-jsonschema/issues/164
"""
json_schema = JSONSchema()

def generate_recursive_schema_with_name():
class RecursiveSchema(Schema):
# when nesting recursively you can either refer the recursive schema by its name
nested_mwe_recursive = fields.Nested("RecursiveSchema")

return json_schema.dump(RecursiveSchema())

def generate_recursive_schema_with_lambda():
class RecursiveSchema(Schema):
# or you can use a lambda (as suggested in the marshmallow docs)
nested_mwe_recursive = fields.Nested(lambda: RecursiveSchema())

return json_schema.dump(
RecursiveSchema()
) # this shall _not_ raise an AttributeError

lambda_schema = generate_recursive_schema_with_lambda()
name_schema = generate_recursive_schema_with_name()
assert lambda_schema == name_schema

0 comments on commit 7446d3d

Please sign in to comment.