diff --git a/ding/utils/deprecation.py b/ding/utils/deprecation.py new file mode 100644 index 0000000000..b71c14977f --- /dev/null +++ b/ding/utils/deprecation.py @@ -0,0 +1,52 @@ +import functools +import textwrap +import warnings +from typing import Optional + + +def deprecated(since: str, removed_in: str, up_to: Optional[str] = None): + """ + Overview: + Decorate a function to signify its deprecation. + Arguments: + - since (:obj:`str`): the version when the function was first deprecated. + - removed_in (:obj:`str`): the version when the function will be removed. + - up_to (:obj:`Optional[str]`): the new API users should use. + Returns: + - decorator (:obj:`Callable`): decorated function. + Examples: + >>> from ding.utils.deprecation import deprecated + >>> @deprecated('0.4.1', '0.5.1') + >>> def hello(): + >>> print('hello') + """ + + def decorator(func): + existing_docstring = func.__doc__ or "" + + deprecated_doc = f'.. deprecated:: {since}\n Deprecated and will be removed in version {removed_in}' + + if up_to is not None: + deprecated_doc += f', please use `{up_to}` instead.' + else: + deprecated_doc += '.' + + func.__doc__ = deprecated_doc + "\n\n" + textwrap.dedent(existing_docstring) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + warning_msg = ( + f'API `{func.__module__}.{func.__name__}` is deprecated since version {since} ' + f'and will be removed in version {removed_in}' + ) + if up_to is not None: + warning_msg += f", please use `{up_to}` instead." + else: + warning_msg += "." + + warnings.warn(warning_msg, category=FutureWarning, stacklevel=2) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/ding/utils/tests/test_deprecation.py b/ding/utils/tests/test_deprecation.py new file mode 100644 index 0000000000..c5304825a2 --- /dev/null +++ b/ding/utils/tests/test_deprecation.py @@ -0,0 +1,28 @@ +import pytest +import warnings +from ding.utils.deprecation import deprecated + + +@pytest.mark.unittest +def test_deprecated(): + + @deprecated('0.4.1', '0.5.1') + def deprecated_func1(): + pass + + @deprecated('0.4.1', '0.5.1', 'deprecated_func3') + def deprecated_func2(): + pass + + with warnings.catch_warnings(record=True) as w: + deprecated_func1() + assert ( + 'API `test_deprecation.deprecated_func1` is deprecated ' + 'since version 0.4.1 and will be removed in version 0.5.1.' + ) == str(w[-1].message) + deprecated_func2() + assert ( + 'API `test_deprecation.deprecated_func2` is deprecated ' + 'since version 0.4.1 and will be removed in version 0.5.1, ' + 'please use `deprecated_func3` instead.' + ) == str(w[-1].message)