From 2447daec00b5759a58e938096327fc5a6d8f2644 Mon Sep 17 00:00:00 2001 From: WassCodeur Date: Tue, 21 May 2024 14:35:41 +0000 Subject: [PATCH] NF: Add keyword_only decorator to enforce keyword-only arguments --- fury/decorators.py | 109 ++++++++++++++++++++++++++++++++++ fury/tests/test_decorators.py | 19 +++++- 2 files changed, 127 insertions(+), 1 deletion(-) diff --git a/fury/decorators.py b/fury/decorators.py index 4cb1ee2cf..185cd6a4b 100644 --- a/fury/decorators.py +++ b/fury/decorators.py @@ -1,5 +1,7 @@ """Decorators for FURY tests.""" +from functools import wraps +from inspect import signature import platform import re import sys @@ -43,3 +45,110 @@ def doctest_skip_parser(func): new_lines.append(code) func.__doc__ = "\n".join(new_lines) return func + + +def keyword_only(func): + """A decorator to enforce keyword-only arguments. + + This decorator is used to enforce that certain arguments of a function + are passed as keyword arguments. This is useful to prevent users from + passing arguments in the wrong order. + + Parameters + ---------- + func : callable + The function to decorate. + + Returns + ------- + callable + The decorated function. + + Examples + -------- + >>> @keyword_only + ... def add(*, a, b): + ... return a + b + >>> add(a=1, b=2) + 3 + >>> add(b=2, a=1, c=3) + Traceback (most recent call last): + ... + TypeError: add() got an unexpected keyword arguments: c + Usage: add(a=[your_value], b=[your_value]) + Please Provide keyword-only arguments: a=[your_value], b=[your_value] + >>> add(1, 2) + Traceback (most recent call last): + ... + TypeError: add() takes 0 positional arguments but 2 were given + Usage: add(a=[your_value], b=[your_value]) + Please Provide keyword-only arguments: a=[your_value], b=[your_value] + >>> add(a=1) + Traceback (most recent call last): + ... + TypeError: add() missing 1 required keyword-only arguments: b + Usage: add(a=[your_value], b=[your_value]) + Please Provide keyword-only arguments: a=[your_value], b=[your_value] + """ + + @wraps(func) + def wrapper(*args, **kwargs): + sig = signature(func) + params = sig.parameters + missing_params = [ + arg.name + for arg in params.values() + if arg.name not in kwargs and arg.kind == arg.KEYWORD_ONLY + ] + params_sample = [ + f"{arg}=[your_value]" + for arg in params.values() + if arg.kind == arg.KEYWORD_ONLY + ] + params_sample_str = ", ".join(params_sample) + unexpected_params_list = [arg for arg in kwargs if arg not in params] + unexpected_params = ", ".join(unexpected_params_list) + if args: + raise TypeError( + ( + "{}() takes 0 positional arguments but {} were given\n" + "Usage: {}({})\n" + "Please Provide keyword-only arguments: {}" + ).format( + func.__name__, + len(args), + func.__name__, + params_sample_str, + params_sample_str, + ) + ) + else: + if unexpected_params: + raise TypeError( + "{}() got an unexpected keyword arguments: {}\n" + "Usage: {}({})\n" + "Please Provide keyword-only arguments: {}".format( + func.__name__, + unexpected_params, + func.__name__, + params_sample_str, + params_sample_str, + ) + ) + + elif missing_params: + raise TypeError( + "{}() missing {} required keyword-only arguments: {}\n" + "Usage: {}({})\n" + "Please Provide keyword-only arguments: {}".format( + func.__name__, + len(missing_params), + ", ".join(missing_params), + func.__name__, + params_sample_str, + params_sample_str, + ) + ) + return func(*args, **kwargs) + + return wrapper diff --git a/fury/tests/test_decorators.py b/fury/tests/test_decorators.py index 914424484..1842bdb59 100644 --- a/fury/tests/test_decorators.py +++ b/fury/tests/test_decorators.py @@ -2,7 +2,7 @@ import numpy.testing as npt -from fury.decorators import doctest_skip_parser +from fury.decorators import doctest_skip_parser, keyword_only from fury.testing import assert_true HAVE_AMODULE = False @@ -49,3 +49,20 @@ def f(): del HAVE_AMODULE f.__doc__ = docstring npt.assert_raises(NameError, doctest_skip_parser, f) + + +def test_keyword_only(): + @keyword_only + def f(*, a, b): + return a + b + + npt.assert_equal(f(a=1, b=2), 3) + npt.assert_raises(TypeError, f, a=1, b=2, c=2) + npt.assert_raises(TypeError, f, 1, 2) + npt.assert_raises(TypeError, f, 1, b=2) + npt.assert_raises(TypeError, f, a=1, b=2) + npt.assert_raises( + TypeError, + f, + a=1, + )