Skip to content

Commit

Permalink
support print hooks before running. (#1123)
Browse files Browse the repository at this point in the history
* support print using hooks before running.

* Support to print hook trigger stages.

* Print stage-wise hook infos. And make `stages` as class attribute of
`Hook`

* Add util function `is_method_overriden` and use it in
`Hook.get_trigger_stages`.

* Add unit tests.

* Move `is_method_overriden` to `mmcv/utils/misc.py`

* Improve hook info text.

* Add base_class argument type assertion, and fix some typos.

* Remove `get_trigger_stages` to `get_triggered_stages`

* Use f-string.
  • Loading branch information
mzr1996 authored Jun 25, 2021
1 parent 227e7a7 commit 1b15f02
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 8 deletions.
25 changes: 24 additions & 1 deletion mmcv/runner/base_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook
from .log_buffer import LogBuffer
from .priority import get_priority
from .priority import Priority, get_priority
from .utils import get_time_str


Expand Down Expand Up @@ -306,6 +306,29 @@ def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)

def get_hook_info(self):
# Get hooks info in each stage
stage_hook_map = {stage: [] for stage in Hook.stages}
for hook in self.hooks:
try:
priority = Priority(hook.priority).name
except ValueError:
priority = hook.priority
classname = hook.__class__.__name__
hook_info = f'({priority:<12}) {classname:<35}'
for trigger_stage in hook.get_triggered_stages():
stage_hook_map[trigger_stage].append(hook_info)

stage_hook_infos = []
for stage in Hook.stages:
hook_infos = stage_hook_map[stage]
if len(hook_infos) > 0:
info = f'{stage}:\n'
info += '\n'.join(hook_infos)
info += '\n -------------------- '
stage_hook_infos.append(info)
return '\n'.join(stage_hook_infos)

def load_checkpoint(self,
filename,
map_location='cpu',
Expand Down
2 changes: 2 additions & 0 deletions mmcv/runner/epoch_based_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('Hooks will be executed in the following order:\n%s',
self.get_hook_info())
self.logger.info('workflow: %s, max: %d epochs', workflow,
self._max_epochs)
self.call_hook('before_run')
Expand Down
27 changes: 26 additions & 1 deletion mmcv/runner/hooks/hook.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Copyright (c) Open-MMLab. All rights reserved.
from mmcv.utils import Registry
from mmcv.utils import Registry, is_method_overridden

HOOKS = Registry('hook')


class Hook:
stages = ('before_run', 'before_train_epoch', 'before_train_iter',
'after_train_iter', 'after_train_epoch', 'before_val_epoch',
'before_val_iter', 'after_val_iter', 'after_val_epoch',
'after_run')

def before_run(self, runner):
pass
Expand Down Expand Up @@ -65,3 +69,24 @@ def is_last_epoch(self, runner):

def is_last_iter(self, runner):
return runner.iter + 1 == runner._max_iters

def get_triggered_stages(self):
trigger_stages = set()
for stage in Hook.stages:
if is_method_overridden(stage, Hook, self):
trigger_stages.add(stage)

# some methods will be triggered in multi stages
# use this dict to map method to stages.
method_stages_map = {
'before_epoch': ['before_train_epoch', 'before_val_epoch'],
'after_epoch': ['after_train_epoch', 'after_val_epoch'],
'before_iter': ['before_train_iter', 'before_val_iter'],
'after_iter': ['after_train_iter', 'after_val_iter'],
}

for method, map_stages in method_stages_map.items():
if is_method_overridden(method, Hook, self):
trigger_stages.update(map_stages)

return [stage for stage in Hook.stages if stage in trigger_stages]
2 changes: 2 additions & 0 deletions mmcv/runner/iter_based_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def run(self, data_loaders, workflow, max_iters=None, **kwargs):
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('Hooks will be executed in the following order:\n%s',
self.get_hook_info())
self.logger.info('workflow: %s, max: %d iters', workflow,
self._max_iters)
self.call_hook('before_run')
Expand Down
15 changes: 9 additions & 6 deletions mmcv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
# Copyright (c) Open-MMLab. All rights reserved.
from .config import Config, ConfigDict, DictAction
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
import_modules_from_strings, is_list_of, is_seq_of, is_str,
is_tuple_of, iter_cast, list_cast, requires_executable,
requires_package, slice_list, to_1tuple, to_2tuple,
to_3tuple, to_4tuple, to_ntuple, tuple_cast)
import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, requires_executable, requires_package,
slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
to_ntuple, tuple_cast)
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
Expand All @@ -31,7 +32,8 @@
'digit_version', 'get_git_hash', 'import_modules_from_strings',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple'
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
'is_method_overridden'
]
else:
from .env import collect_env
Expand Down Expand Up @@ -60,5 +62,6 @@
'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script'
'assert_params_all_zeros', 'check_python_script',
'is_method_overridden'
]
19 changes: 19 additions & 0 deletions mmcv/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,22 @@ def new_func(*args, **kwargs):
return new_func

return api_warning_wrapper


def is_method_overridden(method, base_class, derived_class):
"""Check if a method of base class is overridden in derived class.
Args:
method (str): the method name to check.
base_class (type): the class of the base class.
derived_class (type | Any): the class or instance of the derived class.
"""
assert isinstance(base_class, type), \
"base_class doesn't accept instance, Please pass class instead."

if not isinstance(derived_class, type):
derived_class = derived_class.__class__

base_method = getattr(base_class, method)
derived_method = getattr(derived_class, method)
return derived_method != base_method
17 changes: 17 additions & 0 deletions tests/test_runner/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,3 +1070,20 @@ def __init__(self):
key_stripped = re.sub(r'^backbone\.', '', key)
assert torch.equal(model.state_dict()[key_stripped], state_dict[key])
os.remove(checkpoint_path)


def test_get_triggered_stages():

class ToyHook(Hook):
# test normal stage
def before_run():
pass

# test the method mapped to multi stages.
def after_epoch():
pass

hook = ToyHook()
# stages output have order, so here is list instead of set.
expected_stages = ['before_run', 'after_train_epoch', 'after_val_epoch']
assert hook.get_triggered_stages() == expected_stages
30 changes: 30 additions & 0 deletions tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,33 @@ def test_import_modules_from_strings():
['os.path', '_not_implemented'], allow_failed_imports=True)
assert imported[0] == osp
assert imported[1] is None


def test_is_method_overridden():

class Base:

def foo1():
pass

def foo2():
pass

class Sub(Base):

def foo1():
pass

# test passing sub class directly
assert mmcv.is_method_overridden('foo1', Base, Sub)
assert not mmcv.is_method_overridden('foo2', Base, Sub)

# test passing instance of sub class
sub_instance = Sub()
assert mmcv.is_method_overridden('foo1', Base, sub_instance)
assert not mmcv.is_method_overridden('foo2', Base, sub_instance)

# base_class should be a class, not instance
base_instance = Base()
with pytest.raises(AssertionError):
mmcv.is_method_overridden('foo1', base_instance, sub_instance)

0 comments on commit 1b15f02

Please sign in to comment.