# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import json
from collections.abc import Mapping, MutableMapping
from typing import List

__all__ = ["ClassyConfigDict", "ConfigError", "ConfigUnusedKeysError"]


class ConfigError(Exception):
    """Base class for configuration related errors."""


class ConfigUnusedKeysError(ConfigError):
    """Raised when a config contains keys which were never read."""

    def __init__(self, unused_keys: List[str]):
        """Store the offending keys and build the error message.

        Args:
            unused_keys: Dotted paths of the keys which were never read.
        """
        self.unused_keys = unused_keys
        super().__init__(f"The following keys were unused: {self.unused_keys}")


class ClassyConfigDict(MutableMapping):
    """Mapping which can be made immutable. Also supports tracking unused keys."""

    def __init__(self, *args, **kwargs):
        """Create a ClassyConfigDict.

        Supports the same API as a dict and recursively converts all dicts to
        ClassyConfigDicts.
        """

        # NOTE: Another way to implement this would be to subclass dict, but since
        # dict is a built-in, it isn't treated like a regular MutableMapping, and
        # calls like func(**map) are handled mysteriously, probably interpreter
        # dependent.
        # The downside with this implementation is that this isn't a full dict and
        # is just a mapping, which means some features like JSON serialization
        # don't work

        self._frozen = False
        self._keys_read = set()
        self._dict = {
            key: self._from_dict(value)
            for key, value in dict(*args, **kwargs).items()
        }

    @classmethod
    def _from_dict(cls, obj):
        """Recursively convert all mappings inside obj to ClassyConfigDicts.

        Mappings are always copied into a new ClassyConfigDict; lists and tuples
        are converted to new lists. Any other object is returned unchanged.
        """
        if isinstance(obj, Mapping):
            # the constructor recurses into the values, so converting the
            # children here as well would repeat the work at every nesting level
            return cls(obj)
        if isinstance(obj, (list, tuple)):
            # tuples are also converted to lists
            return [cls._from_dict(v) for v in obj]
        return obj

    @classmethod
    def _wrap(cls, value):
        """Prepare a value for storage through a mutating call.

        Existing ClassyConfigDicts are stored as-is, so their identity and
        tracking state survive. Other mappings become ClassyConfigDicts, and
        lists / tuples are converted to lists with their elements wrapped, which
        is what freezing and unused-key detection expect to find.
        """
        if isinstance(value, ClassyConfigDict):
            return value
        if isinstance(value, Mapping):
            return cls(value)
        if isinstance(value, (list, tuple)):
            return [cls._wrap(v) for v in value]
        return value

    def to_dict(self):
        """Return a vanilla Python dict, converting dicts recursively.

        The conversion goes through items(), so every key is marked as read.
        """
        return self._to_dict(self)

    @classmethod
    def _to_dict(cls, obj):
        """Recursively convert obj to vanilla Python dicts"""
        if isinstance(obj, ClassyConfigDict):
            return {k: cls._to_dict(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            # tuples are also converted to lists
            return [cls._to_dict(v) for v in obj]
        return obj

    def keys(self):
        """Return a view of the keys; does not mark anything as read."""
        return self._dict.keys()

    def items(self):
        """Return a view of the (key, value) pairs and mark all keys as read."""
        self._keys_read.update(self._dict)
        return self._dict.items()

    def values(self):
        """Return a view of the values and mark all keys as read."""
        self._keys_read.update(self._dict)
        return self._dict.values()

    def pop(self, key, default=None):
        """Remove key and return its value, or default if the key is missing.

        Raises:
            TypeError: If the ClassyConfigDict is frozen.
        """
        if self._frozen:
            raise TypeError("Frozen ClassyConfigDicts do not support key deletion")
        return self._dict.pop(key, default)

    def popitem(self):
        """Remove and return a (key, value) pair.

        Raises:
            TypeError: If the ClassyConfigDict is frozen.
            KeyError: If the ClassyConfigDict is empty.
        """
        if self._frozen:
            raise TypeError("Frozen ClassyConfigDicts do not support key deletion")
        return self._dict.popitem()

    def clear(self):
        """Remove all the keys.

        Raises:
            TypeError: If the ClassyConfigDict is frozen.
        """
        if self._frozen:
            raise TypeError("Frozen ClassyConfigDicts do not support key deletion")
        self._dict.clear()

    def update(self, *args, **kwargs):
        """Update the mapping; accepts the same arguments as dict.update.

        Raises:
            TypeError: If the ClassyConfigDict is frozen.
        """
        if self._frozen:
            raise TypeError("Frozen ClassyConfigDicts do not support updates")
        # go through __setitem__ so that nested dicts are converted
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def setdefault(self, key, default=None):
        """Return the value for key, inserting default if the key is missing.

        Like the original implementation, this does not mark the key as read.

        Raises:
            TypeError: If the key is missing and the ClassyConfigDict is frozen.
        """
        if key in self._dict:
            # nothing is mutated, so this is allowed on frozen instances too
            return self._dict[key]
        if self._frozen:
            raise TypeError("Frozen ClassyConfigDicts do not support assignment")
        value = self._wrap(default)
        self._dict[key] = value
        return value

    def __contains__(self, key):
        return key in self._dict

    def __eq__(self, obj):
        return self._dict == obj

    def __len__(self):
        return len(self._dict)

    def __getitem__(self, key):
        # the key is recorded even if the lookup below raises a KeyError
        self._keys_read.add(key)
        return self._dict[key]

    def __iter__(self):
        return iter(self._dict)

    def __str__(self):
        return json.dumps(self.to_dict(), indent=4)

    def __repr__(self):
        return repr(self._dict)

    def get(self, key, default=None):
        """Return the value for key (marking it as read), else default."""
        if key in self._dict:
            self._keys_read.add(key)
        return self._dict.get(key, default)

    def __copy__(self):
        """Return an unfrozen shallow copy and mark all the keys as read."""
        ret = ClassyConfigDict()
        self._keys_read.update(self._dict)
        ret._dict.update(self._dict)
        return ret

    def copy(self):
        """Return an unfrozen shallow copy and mark all the keys as read."""
        return self.__copy__()

    def __deepcopy__(self, memo=None):
        """Return an unfrozen deep copy, marking all keys and sub-keys as read."""
        ret = ClassyConfigDict()
        if memo is not None:
            # register the copy so shared / cyclic references resolve to it
            memo[id(self)] = ret
        for key, value in self._dict.items():
            self._keys_read.add(key)
            ret._dict[key] = copy.deepcopy(value, memo)
        return ret

    def __setitem__(self, key, value):
        if self._frozen:
            raise TypeError("Frozen ClassyConfigDicts do not support assignment")
        self._dict[key] = self._wrap(value)

    def __delitem__(self, key):
        if self._frozen:
            raise TypeError("Frozen ClassyConfigDicts do not support key deletion")
        del self._dict[key]

    def _freeze(self, obj):
        """Recursively freeze all the ClassyConfigDicts reachable from obj."""
        if isinstance(obj, Mapping):
            assert isinstance(obj, ClassyConfigDict), f"{obj} is not a ClassyConfigDict"
            obj._frozen = True
            # iterate over the raw dict: values() would mark every key as read
            for value in obj._dict.values():
                self._freeze(value)
        elif isinstance(obj, list):
            for value in obj:
                self._freeze(value)

    def _reset_tracking(self, obj):
        """Recursively forget which keys have been read, starting at obj."""
        if isinstance(obj, Mapping):
            assert isinstance(obj, ClassyConfigDict), f"{obj} is not a ClassyConfigDict"
            obj._keys_read = set()
            for value in obj._dict.values():
                self._reset_tracking(value)
        elif isinstance(obj, list):
            for value in obj:
                self._reset_tracking(value)

    def _unused_keys(self, obj):
        """Return the dotted paths of the unread keys reachable from obj.

        List elements contribute their index as a path component, e.g. "b.2.c".
        """
        unused_keys = []
        if isinstance(obj, Mapping):
            assert isinstance(obj, ClassyConfigDict), f"{obj} is not a ClassyConfigDict"
            unused_keys = [key for key in obj._dict if key not in obj._keys_read]
            for key, value in obj._dict.items():
                unused_keys += [
                    f"{key}.{subkey}" for subkey in self._unused_keys(value)
                ]
        elif isinstance(obj, list):
            for i, value in enumerate(obj):
                unused_keys += [f"{i}.{subkey}" for subkey in self._unused_keys(value)]
        return unused_keys

    def freeze(self):
        """Freeze the ClassyConfigDict to disallow mutations"""
        self._freeze(self)

    def reset_tracking(self):
        """Reset key tracking"""
        self._reset_tracking(self)

    def unused_keys(self):
        """Fetch all the unused keys"""
        return self._unused_keys(self)

    def check_unused_keys(self):
        """Raise if the config has unused keys

        Raises:
            ConfigUnusedKeysError: If at least one key was never read.
        """
        unused_keys = self.unused_keys()
        if unused_keys:
            raise ConfigUnusedKeysError(unused_keys)
import ClassyOptimizer, register_optimizer @@ -63,10 +64,11 @@ def from_config(cls, config: Dict[str, Any]) -> "SGD": config.setdefault("weight_decay", 0.0) config.setdefault("nesterov", False) config.setdefault("use_larc", False) - config.setdefault( - "larc_config", {"clip": True, "eps": 1e-08, "trust_coefficient": 0.02} - ) - + if config["use_larc"]: + larc_config = ClassyConfigDict(clip=True, eps=1e-8, trust_coefficient=0.02) + else: + larc_config = None + config.setdefault("larc_config", larc_config) assert ( config["momentum"] >= 0.0 and config["momentum"] < 1.0 diff --git a/classy_vision/tasks/__init__.py b/classy_vision/tasks/__init__.py index e4e0a9d25a..7f8db1c238 100644 --- a/classy_vision/tasks/__init__.py +++ b/classy_vision/tasks/__init__.py @@ -6,11 +6,11 @@ from pathlib import Path +from classy_vision.configuration import ClassyConfigDict from classy_vision.generic.registry_utils import import_all_modules from .classy_task import ClassyTask - FILE_ROOT = Path(__file__).parent @@ -26,8 +26,13 @@ def build_task(config): "foo": "bar"}` will find a class that was registered as "my_task" (see :func:`register_task`) and call .from_config on it.""" + config = ClassyConfigDict(config) + task = TASK_REGISTRY[config["name"]].from_config(config) + # at this stage all the configs keys should have been used + config.check_unused_keys() + return task diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py index f1e685bf32..0365f481ab 100644 --- a/classy_vision/tasks/classification_task.py +++ b/classy_vision/tasks/classification_task.py @@ -494,6 +494,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask": Returns: A ClassificationTask instance. 
""" + test_only = config.get("test_only", False) if not test_only: # TODO Make distinction between epochs and phases in optimizer clear @@ -1252,7 +1253,6 @@ def log_phase_end(self, tag): def __repr__(self): if hasattr(self, "_config"): - config = json.dumps(self._config, indent=4) - return f"{super().__repr__()} initialized with config:\n{config}" + return f"{super().__repr__()} initialized with config:\n{self._config}" return super().__repr__() diff --git a/test/configuration_classy_config_dict_test.py b/test/configuration_classy_config_dict_test.py new file mode 100644 index 0000000000..2bbc3f9119 --- /dev/null +++ b/test/configuration_classy_config_dict_test.py @@ -0,0 +1,118 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest + +from classy_vision.configuration import ClassyConfigDict + + +class ClassyConfigDictTest(unittest.TestCase): + def test_dict(self): + d = ClassyConfigDict(a=1, b=[1, 2, "3"]) + d["c"] = [4] + d["d"] = {"a": 2} + self.assertEqual(d, {"a": 1, "b": [1, 2, "3"], "c": [4], "d": {"a": 2}}) + self.assertIsInstance(d, ClassyConfigDict) + self.assertIsInstance(d["d"], ClassyConfigDict) + + def test_freezing(self): + d = ClassyConfigDict(a=1, b=2) + d.freeze() + # resetting an already existing key + with self.assertRaises(TypeError): + d["a"] = 3 + # adding a new key + with self.assertRaises(TypeError): + d["f"] = 3 + + def test_unused_keys(self): + d = ClassyConfigDict( + a=1, + b=[ + 1, + 2, + { + "c": {"a": 2}, + "d": 4, + "e": {"a": 1, "b": 2}, + "f": {"a": 1, "b": {"c": 2}}, + }, + ], + ) + + all_keys = { + "a", + "b", + "b.2.c", + "b.2.c.a", + "b.2.d", + "b.2.e", + "b.2.f", + "b.2.e.a", + "b.2.e.b", + "b.2.f.a", + "b.2.f.b", + "b.2.f.b.c", + } + + def test_func(**kwargs): + return None + + for _ in range(2): + expected_unused_keys = all_keys.copy() + 
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys) + + _ = d["a"] + expected_unused_keys.remove("a") + self.assertSetEqual(set(d.unused_keys()), expected_unused_keys) + + _ = d["b"][2].get("d") + expected_unused_keys.remove("b") + expected_unused_keys.remove("b.2.d") + self.assertSetEqual(set(d.unused_keys()), expected_unused_keys) + + _ = d["b"][2]["e"] + expected_unused_keys.remove("b.2.e") + self.assertSetEqual(set(d.unused_keys()), expected_unused_keys) + + _ = d["b"][2]["e"].items() + expected_unused_keys.remove("b.2.e.a") + expected_unused_keys.remove("b.2.e.b") + self.assertSetEqual(set(d.unused_keys()), expected_unused_keys) + + _ = d["b"][2]["f"] + expected_unused_keys.remove("b.2.f") + self.assertSetEqual(set(d.unused_keys()), expected_unused_keys) + + test_func(**d["b"][2]["f"]) + expected_unused_keys.remove("b.2.f.a") + expected_unused_keys.remove("b.2.f.b") + self.assertSetEqual(set(d.unused_keys()), expected_unused_keys) + + _ = copy.deepcopy(d) + expected_unused_keys.remove("b.2.c") + expected_unused_keys.remove("b.2.c.a") + expected_unused_keys.remove("b.2.f.b.c") + self.assertSetEqual(set(d.unused_keys()), expected_unused_keys) + + d.reset_tracking() + + def test_to_dict(self): + d = { + "a": 1, + "b": [ + 1, + 2, + { + "c": {"a": 2}, + "d": 4, + "e": {"a": 1, "b": 2}, + "f": {"a": 1, "b": {"c": 2}}, + }, + ], + } + classy_config_dict = ClassyConfigDict(**d) + self.assertEqual(d, classy_config_dict.to_dict()) diff --git a/test/tasks_classification_task_test.py b/test/tasks_classification_task_test.py index 1e708390ab..56024e9306 100644 --- a/test/tasks_classification_task_test.py +++ b/test/tasks_classification_task_test.py @@ -19,6 +19,7 @@ import torch import torch.nn as nn +from classy_vision.configuration import ConfigUnusedKeysError from classy_vision.dataset import build_dataset from classy_vision.generic.distributed_util import is_distributed_training_run from classy_vision.generic.util import get_checkpoint_dict @@ -92,6 +93,10 
@@ def test_build_task(self): task = build_task(config) self.assertTrue(isinstance(task, ClassificationTask)) + config["asd"] = 1 + with self.assertRaises(ConfigUnusedKeysError): + task = build_task(config) + def test_hooks_config_builds_correctly(self): config = get_test_task_config() config["hooks"] = [{"name": "loss_lr_meter_logging"}]