Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelhball committed Feb 7, 2023
1 parent 8bbd3a1 commit e94ac76
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 7 deletions.
20 changes: 13 additions & 7 deletions hydra/_internal/instantiate/_instantiate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,18 +354,24 @@ def instantiate_node(
):
dict_items = {}
for key, value in node.items():
# list items inherits recursive flag from the containing dict.
dict_items[key] = instantiate_node(
value, convert=convert, recursive=recursive
)
if recursive:
# list items inherits recursive flag from the containing dict.
dict_items[key] = instantiate_node(
value, convert=convert, recursive=recursive
)
else:
dict_items[key] = value
return dict_items
else:
# Otherwise use DictConfig and resolve interpolations lazily.
cfg = OmegaConf.create({}, flags={"allow_objects": True})
for key, value in node.items():
cfg[key] = instantiate_node(
value, convert=convert, recursive=recursive
)
if recursive:
cfg[key] = instantiate_node(
value, convert=convert, recursive=recursive
)
else:
cfg[key] = value
cfg._set_parent(node)
cfg._metadata.object_type = node._metadata.object_type
if convert == ConvertMode.OBJECT:
Expand Down
5 changes: 5 additions & 0 deletions tests/instantiate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,11 @@ class NestedConf:
b: Any = field(default_factory=lambda: User(name="b", age=2))


@dataclass
class NestedConfNoTarget:
a: Any = field(default_factory=lambda: SimpleClassDefaultPrimitiveConf)


def recisinstance(got: Any, expected: Any) -> bool:
"""Compare got with expected type, recursively on dict and list."""
if not isinstance(got, type(expected)):
Expand Down
11 changes: 11 additions & 0 deletions tests/instantiate/test_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
Mapping,
MappingConf,
NestedConf,
NestedConfNoTarget,
NestingClass,
OuterClass,
Parameters,
Expand Down Expand Up @@ -2096,3 +2097,13 @@ class DictValuesConf:
cfg = OmegaConf.structured(DictValuesConf)
obj = instantiate_func(config=cfg)
assert obj.d is None



def test_non_target_recursive(instantiate_func: Any) -> None:
cfg = OmegaConf.structured(NestedConfNoTarget)
obj = instantiate_func(config=cfg, _recursive_=False)
assert isinstance(obj.a, DictConfig)

obj = instantiate_func(config=cfg, _recursive_=True)
assert isinstance(obj.a, SimpleClass)

0 comments on commit e94ac76

Please sign in to comment.