Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the argument target_wrapper to hydra.utils.instantiate to support recursive type checking #2880

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions hydra/_internal/instantiate/_instantiate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
from enum import Enum
from textwrap import dedent
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from omegaconf import OmegaConf, SCMode
from omegaconf._utils import is_structured_config
Expand Down Expand Up @@ -145,7 +145,12 @@ def _resolve_target(
return target


def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
def instantiate(
config: Any,
*args: Any,
target_wrapper: Optional[Callable[..., Any]] = None,
**kwargs: Any,
) -> Any:
"""
:param config: An config object describing what to call and what params to use.
In addition to the parameters, the config must contain:
Expand All @@ -168,6 +173,7 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
are converted to dicts / lists too.
_partial_: If True, return functools.partial wrapped method or object
False by default. Configure per target.
:param target_wrapper: Optional callable wrap _target_ with before it itself is called.
:param args: Optional positional parameters pass-through
:param kwargs: Optional named parameters to override
parameters in the config object. Parameters not present
Expand Down Expand Up @@ -224,7 +230,12 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
_partial_ = config.pop(_Keys.PARTIAL, False)

return instantiate_node(
config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
config,
*args,
recursive=_recursive_,
convert=_convert_,
partial=_partial_,
target_wrapper=target_wrapper,
)
elif OmegaConf.is_list(config):
# Finalize config (convert targets to strings, merge with kwargs)
Expand All @@ -247,7 +258,12 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
)

return instantiate_node(
config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
config,
*args,
recursive=_recursive_,
convert=_convert_,
partial=_partial_,
target_wrapper=target_wrapper,
)
else:
raise InstantiationException(
Expand Down Expand Up @@ -281,6 +297,7 @@ def instantiate_node(
convert: Union[str, ConvertMode] = ConvertMode.NONE,
recursive: bool = True,
partial: bool = False,
target_wrapper: Optional[Callable[..., Any]] = None,
) -> Any:
# Return None if config is None
if node is None or (OmegaConf.is_config(node) and node._is_none()):
Expand Down Expand Up @@ -314,7 +331,12 @@ def instantiate_node(
# If OmegaConf list, create new list of instances if recursive
if OmegaConf.is_list(node):
items = [
instantiate_node(item, convert=convert, recursive=recursive)
instantiate_node(
item,
convert=convert,
recursive=recursive,
target_wrapper=target_wrapper,
)
for item in node._iter_ex(resolve=True)
]

Expand All @@ -331,6 +353,9 @@ def instantiate_node(
exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"})
if _is_target(node):
_target_ = _resolve_target(node.get(_Keys.TARGET), full_key)
if target_wrapper:
_target_ = target_wrapper(_target_)

kwargs = {}
is_partial = node.get("_partial_", False) or partial
for key in node.keys():
Expand All @@ -340,7 +365,10 @@ def instantiate_node(
value = node[key]
if recursive:
value = instantiate_node(
value, convert=convert, recursive=recursive
value,
convert=convert,
recursive=recursive,
target_wrapper=target_wrapper,
)
kwargs[key] = _convert_node(value, convert)

Expand Down
Loading