Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/scope #2884

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions hydra/_internal/instantiate/_instantiate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
from enum import Enum
from textwrap import dedent
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from omegaconf import OmegaConf, SCMode
from omegaconf._utils import is_structured_config
Expand All @@ -18,6 +18,7 @@ class _Keys(str, Enum):
"""Special keys in configs used by instantiate."""

TARGET = "_target_"
SCOPE = "_scope_"
CONVERT = "_convert_"
RECURSIVE = "_recursive_"
ARGS = "_args_"
Expand All @@ -32,6 +33,14 @@ def _is_target(x: Any) -> bool:
return False


def _is_scope(x: Any) -> bool:
if isinstance(x, dict):
return "_scope_" in x
if OmegaConf.is_dict(x):
return "_scope_" in x
return False


def _extract_pos_args(input_args: Any, kwargs: Any) -> Tuple[Any, Any]:
config_args = kwargs.pop(_Keys.ARGS, ())
output_args = config_args
Expand Down Expand Up @@ -278,6 +287,7 @@ def _convert_node(node: Any, convert: Union[ConvertMode, str]) -> Any:
def instantiate_node(
node: Any,
*args: Any,
scope: Optional[str] = None,
convert: Union[str, ConvertMode] = ConvertMode.NONE,
recursive: bool = True,
partial: bool = False,
Expand Down Expand Up @@ -314,7 +324,7 @@ def instantiate_node(
# If OmegaConf list, create new list of instances if recursive
if OmegaConf.is_list(node):
items = [
instantiate_node(item, convert=convert, recursive=recursive)
instantiate_node(item, convert=convert, recursive=recursive, scope=scope)
for item in node._iter_ex(resolve=True)
]

Expand All @@ -328,9 +338,27 @@ def instantiate_node(
return lst

elif OmegaConf.is_dict(node):
exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"})
exclude_keys = set({"_target_", "_scope_", "_convert_", "_recursive_", "_partial_"})
if _is_scope(node):
scope = node.pop(_Keys.SCOPE, None)
if scope is not None:
if not isinstance(scope, str):
raise InstantiationException(
f"Scope must be a string, got {type(scope).__name__}"
)
if scope == "":
raise InstantiationException("Scope cannot be an empty string")
return instantiate_node(
node, convert=convert, recursive=recursive, partial=partial, scope=scope
)

if _is_target(node):
_target_ = _resolve_target(node.get(_Keys.TARGET), full_key)
_target_ = node.get(_Keys.TARGET)
if scope is not None and _target_.startswith("."):
_target_ = f"{scope}{_target_}"

_target_ = _resolve_target(_target_, full_key)

kwargs = {}
is_partial = node.get("_partial_", False) or partial
for key in node.keys():
Expand All @@ -340,7 +368,7 @@ def instantiate_node(
value = node[key]
if recursive:
value = instantiate_node(
value, convert=convert, recursive=recursive
value, convert=convert, recursive=recursive, scope=scope
)
kwargs[key] = _convert_node(value, convert)

Expand All @@ -356,15 +384,15 @@ def instantiate_node(
for key, value in node.items():
# list items inherits recursive flag from the containing dict.
dict_items[key] = instantiate_node(
value, convert=convert, recursive=recursive
value, convert=convert, recursive=recursive, scope=scope
)
return dict_items
else:
# Otherwise use DictConfig and resolve interpolations lazily.
cfg = OmegaConf.create({}, flags={"allow_objects": True})
for key, value in node.items():
cfg[key] = instantiate_node(
value, convert=convert, recursive=recursive
value, convert=convert, recursive=recursive, scope=scope
)
cfg._set_parent(node)
cfg._metadata.object_type = node._metadata.object_type
Expand Down