diff --git a/mmcv/runner/base_runner.py b/mmcv/runner/base_runner.py index 870daa5f72..1f1fa01845 100644 --- a/mmcv/runner/base_runner.py +++ b/mmcv/runner/base_runner.py @@ -14,7 +14,7 @@ from .dist_utils import get_dist_info from .hooks import HOOKS, Hook from .log_buffer import LogBuffer -from .priority import get_priority +from .priority import Priority, get_priority from .utils import get_time_str @@ -306,6 +306,29 @@ def call_hook(self, fn_name): for hook in self._hooks: getattr(hook, fn_name)(self) + def get_hook_info(self): + # Get hooks info in each stage + stage_hook_map = {stage: [] for stage in Hook.stages} + for hook in self.hooks: + try: + priority = Priority(hook.priority).name + except ValueError: + priority = hook.priority + classname = hook.__class__.__name__ + hook_info = f'({priority:<12}) {classname:<35}' + for trigger_stage in hook.get_triggered_stages(): + stage_hook_map[trigger_stage].append(hook_info) + + stage_hook_infos = [] + for stage in Hook.stages: + hook_infos = stage_hook_map[stage] + if len(hook_infos) > 0: + info = f'{stage}:\n' + info += '\n'.join(hook_infos) + info += '\n -------------------- ' + stage_hook_infos.append(info) + return '\n'.join(stage_hook_infos) + def load_checkpoint(self, filename, map_location='cpu', diff --git a/mmcv/runner/epoch_based_runner.py b/mmcv/runner/epoch_based_runner.py index 1e1de295ed..b95f2a1f68 100644 --- a/mmcv/runner/epoch_based_runner.py +++ b/mmcv/runner/epoch_based_runner.py @@ -101,6 +101,8 @@ def run(self, data_loaders, workflow, max_epochs=None, **kwargs): work_dir = self.work_dir if self.work_dir is not None else 'NONE' self.logger.info('Start running, host: %s, work_dir: %s', get_host_info(), work_dir) + self.logger.info('Hooks will be executed in the following order:\n%s', + self.get_hook_info()) self.logger.info('workflow: %s, max: %d epochs', workflow, self._max_epochs) self.call_hook('before_run') diff --git a/mmcv/runner/hooks/hook.py b/mmcv/runner/hooks/hook.py index fa8ce4a49f..419f638c5e 100644 --- a/mmcv/runner/hooks/hook.py +++ b/mmcv/runner/hooks/hook.py @@ -1,10 +1,14 @@ # Copyright (c) Open-MMLab. All rights reserved. -from mmcv.utils import Registry +from mmcv.utils import Registry, is_method_overridden HOOKS = Registry('hook') class Hook: + stages = ('before_run', 'before_train_epoch', 'before_train_iter', + 'after_train_iter', 'after_train_epoch', 'before_val_epoch', + 'before_val_iter', 'after_val_iter', 'after_val_epoch', + 'after_run') def before_run(self, runner): pass @@ -65,3 +69,24 @@ def is_last_epoch(self, runner): def is_last_iter(self, runner): return runner.iter + 1 == runner._max_iters + + def get_triggered_stages(self): + trigger_stages = set() + for stage in Hook.stages: + if is_method_overridden(stage, Hook, self): + trigger_stages.add(stage) + + # some methods will be triggered in multi stages + # use this dict to map method to stages. + method_stages_map = { + 'before_epoch': ['before_train_epoch', 'before_val_epoch'], + 'after_epoch': ['after_train_epoch', 'after_val_epoch'], + 'before_iter': ['before_train_iter', 'before_val_iter'], + 'after_iter': ['after_train_iter', 'after_val_iter'], + } + + for method, map_stages in method_stages_map.items(): + if is_method_overridden(method, Hook, self): + trigger_stages.update(map_stages) + + return [stage for stage in Hook.stages if stage in trigger_stages] diff --git a/mmcv/runner/iter_based_runner.py b/mmcv/runner/iter_based_runner.py index 75133d5ec4..62a46216dd 100644 --- a/mmcv/runner/iter_based_runner.py +++ b/mmcv/runner/iter_based_runner.py @@ -108,6 +108,8 @@ def run(self, data_loaders, workflow, max_iters=None, **kwargs): work_dir = self.work_dir if self.work_dir is not None else 'NONE' self.logger.info('Start running, host: %s, work_dir: %s', get_host_info(), work_dir) + self.logger.info('Hooks will be executed in the following order:\n%s', + self.get_hook_info()) self.logger.info('workflow: %s, max: %d iters', workflow, self._max_iters) self.call_hook('before_run') diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 8649d9aded..6ca3452409 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -2,10 +2,11 @@ # Copyright (c) Open-MMLab. All rights reserved. from .config import Config, ConfigDict, DictAction from .misc import (check_prerequisites, concat_list, deprecated_api_warning, - import_modules_from_strings, is_list_of, is_seq_of, is_str, - is_tuple_of, iter_cast, list_cast, requires_executable, - requires_package, slice_list, to_1tuple, to_2tuple, - to_3tuple, to_4tuple, to_ntuple, tuple_cast) + import_modules_from_strings, is_list_of, + is_method_overridden, is_seq_of, is_str, is_tuple_of, + iter_cast, list_cast, requires_executable, requires_package, + slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, + to_ntuple, tuple_cast) from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, scandir, symlink) from .progressbar import (ProgressBar, track_iter_progress, @@ -31,7 +32,8 @@ 'digit_version', 'get_git_hash', 'import_modules_from_strings', 'assert_dict_contains_subset', 'assert_attrs_equal', 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script', - 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple' + 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', + 'is_method_overridden' ] else: from .env import collect_env @@ -60,5 +62,6 @@ 'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena', 'assert_dict_contains_subset', 'assert_attrs_equal', 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', - 'assert_params_all_zeros', 'check_python_script' + 'assert_params_all_zeros', 'check_python_script', + 'is_method_overridden' ] diff --git a/mmcv/utils/misc.py b/mmcv/utils/misc.py index 1d2517f02d..dee1fa03c9 100644 --- a/mmcv/utils/misc.py +++ b/mmcv/utils/misc.py @@ -333,3 +333,22 @@ def new_func(*args, **kwargs): return new_func return api_warning_wrapper + + +def is_method_overridden(method, base_class, derived_class): + """Check if a method of base class is overridden in derived class. + + Args: + method (str): the method name to check. + base_class (type): the class of the base class. + derived_class (type | Any): the class or instance of the derived class. + """ + assert isinstance(base_class, type), \ + "base_class doesn't accept instance, Please pass class instead." + + if not isinstance(derived_class, type): + derived_class = derived_class.__class__ + + base_method = getattr(base_class, method) + derived_method = getattr(derived_class, method) + return derived_method != base_method diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py index 3f9ba7c03a..2cc010617b 100644 --- a/tests/test_runner/test_hooks.py +++ b/tests/test_runner/test_hooks.py @@ -1070,3 +1070,20 @@ def __init__(self): key_stripped = re.sub(r'^backbone\.', '', key) assert torch.equal(model.state_dict()[key_stripped], state_dict[key]) os.remove(checkpoint_path) + + +def test_get_triggered_stages(): + + class ToyHook(Hook): + # test normal stage + def before_run(): + pass + + # test the method mapped to multi stages. + def after_epoch(): + pass + + hook = ToyHook() + # stages output have order, so here is list instead of set. + expected_stages = ['before_run', 'after_train_epoch', 'after_val_epoch'] + assert hook.get_triggered_stages() == expected_stages diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py index 29819b2faa..7b056554af 100644 --- a/tests/test_utils/test_misc.py +++ b/tests/test_utils/test_misc.py @@ -160,3 +160,33 @@ def test_import_modules_from_strings(): ['os.path', '_not_implemented'], allow_failed_imports=True) assert imported[0] == osp assert imported[1] is None + + +def test_is_method_overridden(): + + class Base: + + def foo1(): + pass + + def foo2(): + pass + + class Sub(Base): + + def foo1(): + pass + + # test passing sub class directly + assert mmcv.is_method_overridden('foo1', Base, Sub) + assert not mmcv.is_method_overridden('foo2', Base, Sub) + + # test passing instance of sub class + sub_instance = Sub() + assert mmcv.is_method_overridden('foo1', Base, sub_instance) + assert not mmcv.is_method_overridden('foo2', Base, sub_instance) + + # base_class should be a class, not instance + base_instance = Base() + with pytest.raises(AssertionError): + mmcv.is_method_overridden('foo1', base_instance, sub_instance)