Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Commit

Permalink
Implement config validation to find unused keys (#665)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #665

Implement a `ClassyConfigDict` type which supports tracking reads and freezing the map (the latter is unused currently).

Added it to `build_task` to catch cases where we don't use any keys passed by users.

This will not catch all instances, like when some components do a deepcopy - we assume all the keys and sub-keys are read in such a situation

Differential Revision: D25321360

fbshipit-source-id: d5bd63c5340575171a1847739025eea7aec576f1
  • Loading branch information
mannatsingh authored and facebook-github-bot committed Jan 23, 2021
1 parent 8592b83 commit f5ea2af
Show file tree
Hide file tree
Showing 8 changed files with 325 additions and 7 deletions.
9 changes: 9 additions & 0 deletions classy_vision/configuration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .classy_config_dict import ClassyConfigDict
from .config_error import ConfigError, ConfigUnusedKeysError

__all__ = ["ClassyConfigDict", "ConfigError", "ConfigUnusedKeysError"]
180 changes: 180 additions & 0 deletions classy_vision/configuration/classy_config_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import json
from collections.abc import MutableMapping, Mapping

from .config_error import ConfigUnusedKeysError


class ClassyConfigDict(MutableMapping):
"""Mapping which can be made immutable. Also supports tracking unused keys."""

def __init__(self, *args, **kwargs):
"""Create a ClassyConfigDict.
Supports the same API as a dict and recursively converts all dicts to
ClassyConfigDicts.
"""

# NOTE: Another way to implement this would be to subclass dict, but since dict
# is a built-in, it isn't treated like a regular MutableMapping, and calls like
# func(**map) are handled mysteriously, probably interpreter dependent.
# The downside with this implementation is that this isn't a full dict and is
# just a mapping, which means some features like JSON serialization don't work

self._dict = dict(*args, **kwargs)
self._frozen = False
self._keys_read = set()
for k, v in self._dict.items():
self._dict[k] = self._from_dict(v)

@classmethod
def _from_dict(cls, obj):
"""Recursively convert all dicts inside obj to ClassyConfigDicts"""

if isinstance(obj, Mapping):
obj = ClassyConfigDict({k: cls._from_dict(v) for k, v in obj.items()})
elif isinstance(obj, (list, tuple)):
# tuples are also converted to lists
obj = [cls._from_dict(v) for v in obj]
return obj

def keys(self):
return self._dict.keys()

def items(self):
self._keys_read.update(self._dict.keys())
return self._dict.items()

def values(self):
self._keys_read.update(self._dict.keys())
return self._dict.values()

def pop(self, key, default=None):
return self._dict.pop(key, default)

def popitem(self):
return self._dict.popitem()

def clear(self):
self._dict.clear()

def update(self, *args, **kwargs):
if self._frozen:
raise TypeError("Frozen ClassyConfigDicts do not support updates")
self._dict.update(*args, **kwargs)

def setdefault(self, key, default=None):
return self._dict.setdefault(key, default)

def __contains__(self, key):
return key in self._dict

def __eq__(self, obj):
return self._dict == obj

def __len__(self):
return len(self._dict)

def __getitem__(self, key):
self._keys_read.add(key)
return self._dict.__getitem__(key)

def __iter__(self):
return iter(self._dict)

def __str__(self):
return json.dumps(self.to_dict(), indent=4)

def __repr__(self):
return repr(self._dict)

def get(self, key, default=None):
if key in self._dict.keys():
self._keys_read.add(key)
return self._dict.get(key, default)

def __copy__(self):
ret = ClassyConfigDict()
for key, value in self._dict.items():
self._keys_read.add(key)
ret._dict[key] = value

def copy(self):
return self.__copy__()

def __deepcopy__(self, memo=None):
# for deepcopies we mark all the keys and sub-keys as read
ret = ClassyConfigDict()
for key, value in self._dict.items():
self._keys_read.add(key)
ret._dict[key] = copy.deepcopy(value)
return ret

def __setitem__(self, key, value):
if self._frozen:
raise TypeError("Frozen ClassyConfigDicts do not support assignment")
if isinstance(value, dict) and not isinstance(value, ClassyConfigDict):
value = ClassyConfigDict(value)
self._dict.__setitem__(key, value)

def __delitem__(self, key):
if self._frozen:
raise TypeError("Frozen ClassyConfigDicts do not support key deletion")
del self._dict[key]

def _freeze(self, obj):
if isinstance(obj, Mapping):
assert isinstance(obj, ClassyConfigDict), f"{obj} is not a ClassyConfigDict"
obj._frozen = True
for value in obj.values():
self._freeze(value)
elif isinstance(obj, list):
for value in obj:
self._freeze(value)

def _reset_tracking(self, obj):
if isinstance(obj, Mapping):
assert isinstance(obj, ClassyConfigDict), f"{obj} is not a ClassyConfigDict"
obj._keys_read = set()
for value in obj._dict.values():
self._reset_tracking(value)
elif isinstance(obj, list):
for value in obj:
self._reset_tracking(value)

def _unused_keys(self, obj):
unused_keys = []
if isinstance(obj, Mapping):
assert isinstance(obj, ClassyConfigDict), f"{obj} is not a ClassyConfigDict"
unused_keys = [key for key in obj._dict.keys() if key not in obj._keys_read]
for key, value in obj._dict.items():
unused_keys += [
f"{key}.{subkey}" for subkey in self._unused_keys(value)
]
elif isinstance(obj, list):
for i, value in enumerate(obj):
unused_keys += [f"{i}.{subkey}" for subkey in self._unused_keys(value)]
return unused_keys

def freeze(self):
"""Freeze the ClassyConfigDict to disallow mutations"""
self._freeze(self)

def reset_tracking(self):
"""Reset key tracking"""
self._reset_tracking(self)

def unused_keys(self):
"""Fetch all the unused keys"""
return self._unused_keys(self)

def check_unused_keys(self):
"""Raise if the config has unused keys"""
unused_keys = self.unused_keys()
if unused_keys:
raise ConfigUnusedKeysError(unused_keys)
16 changes: 16 additions & 0 deletions classy_vision/configuration/config_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List


class ConfigError(Exception):
pass


class ConfigUnusedKeysError(ConfigError):
def __init__(self, unused_keys: List[str]):
self.unused_keys = unused_keys
super().__init__(f"The following keys were unused: {self.unused_keys}")
10 changes: 6 additions & 4 deletions classy_vision/optim/sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Dict

import torch.optim
from classy_vision.configuration import ClassyConfigDict

from . import ClassyOptimizer, register_optimizer

Expand Down Expand Up @@ -63,10 +64,11 @@ def from_config(cls, config: Dict[str, Any]) -> "SGD":
config.setdefault("weight_decay", 0.0)
config.setdefault("nesterov", False)
config.setdefault("use_larc", False)
config.setdefault(
"larc_config", {"clip": True, "eps": 1e-08, "trust_coefficient": 0.02}
)

if config["use_larc"]:
larc_config = ClassyConfigDict(clip=True, eps=1e-8, trust_coefficient=0.02)
else:
larc_config = None
config.setdefault("larc_config", larc_config)
assert (
config["momentum"] >= 0.0
and config["momentum"] < 1.0
Expand Down
7 changes: 6 additions & 1 deletion classy_vision/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

from pathlib import Path

from classy_vision.configuration import ClassyConfigDict
from classy_vision.generic.registry_utils import import_all_modules

from .classy_task import ClassyTask


FILE_ROOT = Path(__file__).parent


Expand All @@ -26,8 +26,13 @@ def build_task(config):
"foo": "bar"}` will find a class that was registered as "my_task"
(see :func:`register_task`) and call .from_config on it."""

config = ClassyConfigDict(config)

task = TASK_REGISTRY[config["name"]].from_config(config)

# at this stage all the configs keys should have been used
config.check_unused_keys()

return task


Expand Down
4 changes: 2 additions & 2 deletions classy_vision/tasks/classification_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
Returns:
A ClassificationTask instance.
"""

test_only = config.get("test_only", False)
if not test_only:
# TODO Make distinction between epochs and phases in optimizer clear
Expand Down Expand Up @@ -1252,7 +1253,6 @@ def log_phase_end(self, tag):

def __repr__(self):
if hasattr(self, "_config"):
config = json.dumps(self._config, indent=4)
return f"{super().__repr__()} initialized with config:\n{config}"
return f"{super().__repr__()} initialized with config:\n{self._config}"

return super().__repr__()
101 changes: 101 additions & 0 deletions test/configuration_classy_config_dict_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

from classy_vision.configuration import ClassyConfigDict


class ClassyConfigDictTest(unittest.TestCase):
def test_dict(self):
d = ClassyConfigDict(a=1, b=[1, 2, "3"])
d["c"] = [4]
d["d"] = {"a": 2}
self.assertEqual(d, {"a": 1, "b": [1, 2, "3"], "c": [4], "d": {"a": 2}})
self.assertIsInstance(d, ClassyConfigDict)
self.assertIsInstance(d["d"], ClassyConfigDict)

def test_freezing(self):
d = ClassyConfigDict(a=1, b=2)
d.freeze()
# resetting an already existing key
with self.assertRaises(TypeError):
d["a"] = 3
# adding a new key
with self.assertRaises(TypeError):
d["f"] = 3

def test_unused_keys(self):
d = ClassyConfigDict(
a=1,
b=[
1,
2,
{
"c": {"a": 2},
"d": 4,
"e": {"a": 1, "b": 2},
"f": {"a": 1, "b": {"c": 2}},
},
],
)

all_keys = {
"a",
"b",
"b.2.c",
"b.2.c.a",
"b.2.d",
"b.2.e",
"b.2.f",
"b.2.e.a",
"b.2.e.b",
"b.2.f.a",
"b.2.f.b",
"b.2.f.b.c",
}

def test_func(**kwargs):
return None

for _ in range(2):
expected_unused_keys = all_keys.copy()
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)

_ = d["a"]
expected_unused_keys.remove("a")
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)

_ = d["b"][2].get("d")
expected_unused_keys.remove("b")
expected_unused_keys.remove("b.2.d")
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)

_ = d["b"][2]["e"]
expected_unused_keys.remove("b.2.e")
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)

_ = d["b"][2]["e"].items()
expected_unused_keys.remove("b.2.e.a")
expected_unused_keys.remove("b.2.e.b")
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)

_ = d["b"][2]["f"]
expected_unused_keys.remove("b.2.f")
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)

test_func(**d["b"][2]["f"])
expected_unused_keys.remove("b.2.f.a")
expected_unused_keys.remove("b.2.f.b")
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)

_ = copy.deepcopy(d)
expected_unused_keys.remove("b.2.c")
expected_unused_keys.remove("b.2.c.a")
expected_unused_keys.remove("b.2.f.b.c")
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)

d.reset_tracking()
5 changes: 5 additions & 0 deletions test/tasks_classification_task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import torch
import torch.nn as nn
from classy_vision.configuration import ConfigUnusedKeysError
from classy_vision.dataset import build_dataset
from classy_vision.generic.distributed_util import is_distributed_training_run
from classy_vision.generic.util import get_checkpoint_dict
Expand Down Expand Up @@ -92,6 +93,10 @@ def test_build_task(self):
task = build_task(config)
self.assertTrue(isinstance(task, ClassificationTask))

config["asd"] = 1
with self.assertRaises(ConfigUnusedKeysError):
task = build_task(config)

def test_hooks_config_builds_correctly(self):
config = get_test_task_config()
config["hooks"] = [{"name": "loss_lr_meter_logging"}]
Expand Down

0 comments on commit f5ea2af

Please sign in to comment.