Skip to content

Commit

Permalink
refactoring: rename some functions (omry#911)
Browse files Browse the repository at this point in the history
* rename _is_union -> is_union_annotation
* rename valid_value_annotation_type -> is_valid_value_annotation
* rename is_primitive_type -> is_primitive_type_annotation
* rename get_ref_type -> get_type_hint
  • Loading branch information
Jasha10 authored May 5, 2022
1 parent c8fc02c commit 32b267d
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 72 deletions.
30 changes: 11 additions & 19 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,7 @@
from contextlib import contextmanager
from enum import Enum
from textwrap import dedent
from typing import (
Any,
Dict,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
get_type_hints,
)
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union

import yaml

Expand Down Expand Up @@ -176,7 +166,7 @@ def _get_class(path: str) -> type:
return klass


def _is_union(type_: Any) -> bool:
def is_union_annotation(type_: Any) -> bool:
return getattr(type_, "__origin__", None) is Union


Expand Down Expand Up @@ -300,7 +290,7 @@ def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, A
value = attrib.default
if value == attr.NOTHING:
value = MISSING
if _is_union(type_):
if is_union_annotation(type_):
e = ConfigValueError(
f"Union types are not supported:\n{name}: {type_str(type_)}"
)
Expand Down Expand Up @@ -332,6 +322,8 @@ def get_dataclass_init_field_names(obj: Any) -> List[str]:
def get_dataclass_data(
obj: Any, allow_objects: Optional[bool] = None
) -> Dict[str, Any]:
from typing import get_type_hints

from omegaconf.omegaconf import MISSING, OmegaConf, _node_wrap

flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
Expand All @@ -358,7 +350,7 @@ def get_dataclass_data(
else:
value = MISSING

if _is_union(type_):
if is_union_annotation(type_):
e = ConfigValueError(
f"Union types are not supported:\n{name}: {type_str(type_)}"
)
Expand Down Expand Up @@ -672,11 +664,11 @@ def get_dict_key_value_types(ref_type: Any) -> Tuple[Any, Any]:
return key_type, element_type


def valid_value_annotation_type(type_: Any) -> bool:
def is_valid_value_annotation(type_: Any) -> bool:
_, type_ = _resolve_optional(type_)
return (
type_ is Any
or is_primitive_type(type_)
or is_primitive_type_annotation(type_)
or is_structured_config(type_)
or is_container_annotation(type_)
)
Expand All @@ -688,7 +680,7 @@ def _valid_dict_key_annotation_type(type_: Any) -> bool:
return type_ is None or type_ is Any or issubclass(type_, DictKeyType.__args__) # type: ignore


def is_primitive_type(type_: Any) -> bool:
def is_primitive_type_annotation(type_: Any) -> bool:
type_ = get_type_of(type_)
return issubclass(type_, Enum) or type_ in (
int,
Expand All @@ -715,7 +707,7 @@ def _get_value(value: Any) -> Any:
return value


def get_ref_type(obj: Any, key: Any = None) -> Optional[Type[Any]]:
def get_type_hint(obj: Any, key: Any = None) -> Optional[Type[Any]]:
from omegaconf import Container, Node

if isinstance(obj, Container):
Expand Down Expand Up @@ -794,7 +786,7 @@ def format_and_raise(
object_type = OmegaConf.get_type(node)
object_type_str = type_str(object_type)

ref_type = get_ref_type(node)
ref_type = get_type_hint(node)
ref_type_str = type_str(ref_type)

msg = string.Template(msg).safe_substitute(
Expand Down
4 changes: 2 additions & 2 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
_is_missing_value,
format_and_raise,
get_value_kind,
is_valid_value_annotation,
split_key,
type_str,
valid_value_annotation_type,
)
from .errors import (
ConfigKeyError,
Expand Down Expand Up @@ -86,7 +86,7 @@ def __post_init__(self) -> None:
self.ref_type = Any
assert self.key_type is Any or isinstance(self.key_type, type)
if self.element_type is not None:
if not valid_value_annotation_type(self.element_type):
if not is_valid_value_annotation(self.element_type):
raise ValidationError(
f"Unsupported value type: '{type_str(self.element_type, include_module_name=True)}'"
)
Expand Down
17 changes: 10 additions & 7 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@
_is_missing_value,
_is_none,
_is_special,
_is_union,
_resolve_optional,
get_ref_type,
get_structured_config_data,
get_type_hint,
get_value_kind,
get_yaml_loader,
is_container_annotation,
is_dict_annotation,
is_list_annotation,
is_primitive_dict,
is_primitive_type,
is_primitive_type_annotation,
is_structured_config,
is_tuple_annotation,
is_union_annotation,
)
from .base import Container, ContainerMetadata, DictKeyType, Node, SCMode
from .errors import (
Expand Down Expand Up @@ -106,7 +106,7 @@ def __getstate__(self) -> Dict[str, Any]:
assert False
if sys.version_info < (3, 7): # pragma: no cover
element_type = self._metadata.element_type
if _is_union(element_type):
if is_union_annotation(element_type):
raise OmegaConfBaseException(
"Serializing structured configs with `Union` element type requires python >= 3.7"
)
Expand Down Expand Up @@ -283,7 +283,7 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
assert isinstance(dest, DictConfig)
assert isinstance(src, DictConfig)
src_type = src._metadata.object_type
src_ref_type = get_ref_type(src)
src_ref_type = get_type_hint(src)
assert src_ref_type is not None

# If source DictConfig is:
Expand Down Expand Up @@ -568,7 +568,10 @@ def assign(value_key: Any, val: Node) -> None:
and target_node_ref._has_ref_type()
)
or (target_is_vnode and not isinstance(target_node_ref, AnyNode))
or (isinstance(target_node_ref, AnyNode) and is_primitive_type(value))
or (
isinstance(target_node_ref, AnyNode)
and is_primitive_type_annotation(value)
)
)
if should_set_value:
if special_value and isinstance(value, Node):
Expand Down Expand Up @@ -839,7 +842,7 @@ def _shallow_validate_type_hint(node: Node, type_hint: Any) -> None:
elif vk in (ValueKind.MANDATORY_MISSING, ValueKind.INTERPOLATION):
return
elif vk == ValueKind.VALUE:
if is_primitive_type(ref_type) and isinstance(node, ValueNode):
if is_primitive_type_annotation(ref_type) and isinstance(node, ValueNode):
value = node._value()
if not isinstance(value, ref_type):
raise ValidationError(
Expand Down
6 changes: 4 additions & 2 deletions omegaconf/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,15 @@ def __init__(
)

def _validate_and_convert_impl(self, value: Any) -> Any:
from ._utils import is_primitive_type
from ._utils import is_primitive_type_annotation

# allow_objects is internal and not an official API. use at your own risk.
# Please be aware that this support is subject to change without notice.
# If this is deemed useful and supportable it may become an official API.

if self._get_flag("allow_objects") is not True and not is_primitive_type(value):
if self._get_flag(
"allow_objects"
) is not True and not is_primitive_type_annotation(value):
t = get_type_of(value)
raise UnsupportedValueType(
f"Value '{t.__name__}' is not a supported primitive type"
Expand Down
28 changes: 14 additions & 14 deletions tests/structured_conf/test_structured_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
_utils,
flag_override,
)
from omegaconf._utils import _is_optional, get_ref_type
from omegaconf._utils import _is_optional, get_type_hint
from omegaconf.errors import ConfigKeyError, UnsupportedValueType
from tests import IllegalType

Expand Down Expand Up @@ -139,15 +139,15 @@ def test_merge_of_non_subclass_2(self, module: Any, rhs: Any) -> None:
def test_get_type(self, module: Any) -> None:
cfg1 = OmegaConf.create(module.LinkedList)
assert OmegaConf.get_type(cfg1) == module.LinkedList
assert _utils.get_ref_type(cfg1, "next") == Optional[module.LinkedList]
assert _utils.get_type_hint(cfg1, "next") == Optional[module.LinkedList]
assert OmegaConf.get_type(cfg1, "next") is None

assert cfg1.next is None
assert OmegaConf.is_missing(cfg1, "value")

cfg2 = OmegaConf.create(module.MissingTest.Missing1)
assert OmegaConf.is_missing(cfg2, "head")
assert _utils.get_ref_type(cfg2, "head") == module.LinkedList
assert _utils.get_type_hint(cfg2, "head") == module.LinkedList
assert OmegaConf.get_type(cfg2, "head") is None

def test_merge_structured_into_dict(self, module: Any) -> None:
Expand All @@ -164,7 +164,7 @@ def test_merge_structured_into_dict_nested(self, module: Any) -> None:
# type of name becomes str
assert c2 == {"user": {"name": "7", "age": "???"}}
assert isinstance(c2, DictConfig)
assert get_ref_type(c2, "user") == module.User
assert get_type_hint(c2, "user") == module.User

def test_merge_structured_into_dict_nested2(self, module: Any) -> None:
c1 = OmegaConf.create({"user": {"name": IntegerNode(value=7)}})
Expand All @@ -173,7 +173,7 @@ def test_merge_structured_into_dict_nested2(self, module: Any) -> None:
# type of name remains int
assert c2 == {"user": {"name": 7, "age": "???"}}
assert isinstance(c2, DictConfig)
assert get_ref_type(c2, "user") == module.User
assert get_type_hint(c2, "user") == module.User

def test_merge_structured_into_dict_nested3(self, module: Any) -> None:
c1 = OmegaConf.create({"user": {"name": "alice"}})
Expand All @@ -182,7 +182,7 @@ def test_merge_structured_into_dict_nested3(self, module: Any) -> None:
# name is not changed
assert c2 == {"user": {"name": "alice", "age": "???"}}
assert isinstance(c2, DictConfig)
assert get_ref_type(c2, "user") == module.UserWithDefaultName
assert get_type_hint(c2, "user") == module.UserWithDefaultName

def test_merge_missing_object_onto_typed_dictconfig(self, module: Any) -> None:
c1 = OmegaConf.structured(module.DictOfObjects)
Expand All @@ -207,7 +207,7 @@ def test_merge_optional_structured_into_dict(self, module: Any) -> None:
c1 = OmegaConf.create({"user": {"name": "bob"}})
c2 = OmegaConf.merge(c1, module.OptionalUser(module.User(name="alice")))
assert c2.user.name == "alice"
assert get_ref_type(c2, "user") == Optional[module.User]
assert get_type_hint(c2, "user") == Optional[module.User]
assert isinstance(c2, DictConfig)
c2_user = c2._get_node("user")
assert isinstance(c2_user, Node)
Expand All @@ -229,9 +229,9 @@ def test_merge_structured_interpolation_onto_dict(self, module: Any) -> None:
src.user_3 = None
c2 = OmegaConf.merge(c1, src)
assert c2.user_2.name == "bob"
assert get_ref_type(c2, "user_2") == Any
assert get_type_hint(c2, "user_2") == Any
assert c2.user_3 is None
assert get_ref_type(c2, "user_3") == Any
assert get_type_hint(c2, "user_3") == Any

@mark.parametrize("resolve", [True, False])
def test_interpolation_to_structured(self, module: Any, resolve: bool) -> None:
Expand Down Expand Up @@ -263,24 +263,24 @@ def test_plugin_holder(self, module: Any) -> None:
cfg = OmegaConf.create(module.PluginHolder)

assert _is_optional(cfg, "none")
assert _utils.get_ref_type(cfg, "none") == Optional[module.Plugin]
assert _utils.get_type_hint(cfg, "none") == Optional[module.Plugin]
assert OmegaConf.get_type(cfg, "none") is None

assert not _is_optional(cfg, "missing")
assert _utils.get_ref_type(cfg, "missing") == module.Plugin
assert _utils.get_type_hint(cfg, "missing") == module.Plugin
assert OmegaConf.get_type(cfg, "missing") is None

assert not _is_optional(cfg, "plugin")
assert _utils.get_ref_type(cfg, "plugin") == module.Plugin
assert _utils.get_type_hint(cfg, "plugin") == module.Plugin
assert OmegaConf.get_type(cfg, "plugin") == module.Plugin

cfg.plugin = module.ConcretePlugin()
assert not _is_optional(cfg, "plugin")
assert _utils.get_ref_type(cfg, "plugin") == module.Plugin
assert _utils.get_type_hint(cfg, "plugin") == module.Plugin
assert OmegaConf.get_type(cfg, "plugin") == module.ConcretePlugin

assert not _is_optional(cfg, "plugin2")
assert _utils.get_ref_type(cfg, "plugin2") == module.Plugin
assert _utils.get_type_hint(cfg, "plugin2") == module.Plugin
assert OmegaConf.get_type(cfg, "plugin2") == module.ConcretePlugin

def test_plugin_merge(self, module: Any) -> None:
Expand Down
18 changes: 9 additions & 9 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_nested_config_is_none(self, module: Any) -> None:
cfg = OmegaConf.structured(module.NestedWithNone)
assert cfg == {"plugin": None}
assert OmegaConf.get_type(cfg, "plugin") is None
assert _utils.get_ref_type(cfg, "plugin") == Optional[module.Plugin]
assert _utils.get_type_hint(cfg, "plugin") == Optional[module.Plugin]

def test_nested_config(self, module: Any) -> None:
def validate(cfg: DictConfig) -> None:
Expand Down Expand Up @@ -519,11 +519,11 @@ def test_merge_none_is_none(self, module: Any) -> None:

def test_merge_with_subclass_into_missing(self, module: Any) -> None:
base = OmegaConf.structured(module.PluginHolder)
assert _utils.get_ref_type(base, "missing") == module.Plugin
assert _utils.get_type_hint(base, "missing") == module.Plugin
assert OmegaConf.get_type(base, "missing") is None
res = OmegaConf.merge(base, {"missing": module.Plugin})
assert OmegaConf.get_type(res) == module.PluginHolder
assert _utils.get_ref_type(base, "missing") == module.Plugin
assert _utils.get_type_hint(base, "missing") == module.Plugin
assert OmegaConf.get_type(res, "missing") == module.Plugin

def test_merged_with_nons_subclass(self, module: Any) -> None:
Expand Down Expand Up @@ -924,15 +924,15 @@ def test_recursive_list(self, module: Any) -> None:

def test_create_untyped_dict(self, module: Any) -> None:
cfg = OmegaConf.structured(module.UntypedDict)
assert _utils.get_ref_type(cfg, "dict") == Dict[Any, Any]
assert _utils.get_ref_type(cfg, "opt_dict") == Optional[Dict[Any, Any]]
assert _utils.get_type_hint(cfg, "dict") == Dict[Any, Any]
assert _utils.get_type_hint(cfg, "opt_dict") == Optional[Dict[Any, Any]]
assert cfg.dict == {"foo": "var"}
assert cfg.opt_dict is None

def test_create_untyped_list(self, module: Any) -> None:
cfg = OmegaConf.structured(module.UntypedList)
assert _utils.get_ref_type(cfg, "list") == List[Any]
assert _utils.get_ref_type(cfg, "opt_list") == Optional[List[Any]]
assert _utils.get_type_hint(cfg, "list") == List[Any]
assert _utils.get_type_hint(cfg, "opt_list") == Optional[List[Any]]
assert cfg.list == [1, 2]
assert cfg.opt_list is None

Expand Down Expand Up @@ -994,7 +994,7 @@ def test_str2str_as_sub_node(self, module: Any) -> None:
with warns_dict_subclass_deprecated(module.DictSubclass.Str2Str):
cfg = OmegaConf.create({"foo": module.DictSubclass.Str2Str})
assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Str2Str
assert _utils.get_ref_type(cfg.foo) == Any
assert _utils.get_type_hint(cfg.foo) == Any

cfg.foo.hello = "world"
assert cfg.foo.hello == "world"
Expand Down Expand Up @@ -1028,7 +1028,7 @@ def test_int2str_as_sub_node(self, module: Any) -> None:
with warns_dict_subclass_deprecated(module.DictSubclass.Int2Str):
cfg = OmegaConf.create({"foo": module.DictSubclass.Int2Str})
assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Int2Str
assert _utils.get_ref_type(cfg.foo) == Any
assert _utils.get_type_hint(cfg.foo) == Any

cfg.foo[10] = "ten"
assert cfg.foo[10] == "ten"
Expand Down
Loading

0 comments on commit 32b267d

Please sign in to comment.