diff --git a/brainglobe_utils/general/numerical.py b/brainglobe_utils/general/numerical.py index d45e60a3..d6d5d02a 100644 --- a/brainglobe_utils/general/numerical.py +++ b/brainglobe_utils/general/numerical.py @@ -1,4 +1,5 @@ import argparse +from typing import Literal def is_even(num): @@ -31,14 +32,17 @@ def is_even(num): return True -def check_positive_float(value, none_allowed=True): +def check_positive_float( + value: float | None | Literal["None", "none"], + none_allowed: bool = True, +) -> float | None: """ Used in argparse to enforce positive floats. Source: https://stackoverflow.com/questions/14117415 Parameters ---------- - value : float + value : float or None Input value. none_allowed : bool, optional @@ -46,8 +50,8 @@ def check_positive_float(value, none_allowed=True): Returns ------- - float - Input value, if it's positive. + float or None + Input value, if it's positive, or None. Raises ------ @@ -55,27 +59,30 @@ def check_positive_float(value, none_allowed=True): If input value is invalid. """ ivalue = value - if ivalue is not None: - ivalue = float(ivalue) - if ivalue < 0: + if value in (None, "None", "none"): + if not none_allowed: + raise argparse.ArgumentTypeError(f"{ivalue} is an invalid value.") + value = None + else: + value = float(value) + if value < 0: raise argparse.ArgumentTypeError( - "%s is an invalid positive value" % value + f"{ivalue} is an invalid positive value" ) - else: - if not none_allowed: - raise argparse.ArgumentTypeError("%s is an invalid value." % value) - return ivalue + return value -def check_positive_int(value, none_allowed=True): +def check_positive_int( + value: int | None | Literal["None", "none"], none_allowed: bool = True +) -> int | None: """ Used in argparse to enforce positive ints. Source: https://stackoverflow.com/questions/14117415 Parameters ---------- - value : int + value : int or None Input value. none_allowed : bool, optional @@ -83,8 +90,8 @@ def check_positive_int(value, none_allowed=True): Returns ------- - int - Input value, if it's positive. + int or None + Input value, if it's positive, or None. Raises ------ @@ -92,14 +99,15 @@ def check_positive_int(value, none_allowed=True): If input value is invalid. """ ivalue = value - if ivalue is not None: - ivalue = int(ivalue) - if ivalue < 0: + if value in (None, "None", "none"): + if not none_allowed: + raise argparse.ArgumentTypeError(f"{ivalue} is an invalid value.") + value = None + else: + value = int(value) + if value < 0: raise argparse.ArgumentTypeError( - "%s is an invalid positive value" % value + f"{ivalue} is an invalid positive value" ) - else: - if not none_allowed: - raise argparse.ArgumentTypeError("%s is an invalid value." % value) - return ivalue + return value diff --git a/brainglobe_utils/general/string.py b/brainglobe_utils/general/string.py index 777c5beb..37ac1cdb 100644 --- a/brainglobe_utils/general/string.py +++ b/brainglobe_utils/general/string.py @@ -1,5 +1,6 @@ +import argparse from pathlib import Path -from typing import Optional +from typing import Literal, Optional from natsort import natsorted @@ -53,3 +54,38 @@ def get_text_lines( if return_lines is not None: lines = lines[return_lines] return lines + + +def check_str( + value: str | None | Literal["None", "none"], none_allowed: bool = True +) -> str | None: + """ + Used in argparse to enforce str input. + + Parameters + ---------- + value : str or None + Input value. + + none_allowed : bool, optional + If False, throw an error for None values. + + Returns + ------- + str or None + Input value, if it's str, or None. + + Raises + ------ + argparse.ArgumentTypeError + If input value is invalid. + """ + ivalue = value + if value in (None, "None", "none"): + if not none_allowed: + raise argparse.ArgumentTypeError(f"{ivalue} is an invalid value.") + value = None + else: + value = str(value) + + return value diff --git a/tests/tests/test_general/test_numerical.py b/tests/tests/test_general/test_numerical.py index 97fd5338..6017698b 100644 --- a/tests/tests/test_general/test_numerical.py +++ b/tests/tests/test_general/test_numerical.py @@ -25,10 +25,10 @@ def test_check_positive_float(): with pytest.raises(ArgumentTypeError): assert numerical.check_positive_float(neg_val) - assert numerical.check_positive_float(None) is None - - with pytest.raises(ArgumentTypeError): - assert numerical.check_positive_float(None, none_allowed=False) + for none_val in [None, "None", "none"]: + assert numerical.check_positive_float(none_val) is None + with pytest.raises(ArgumentTypeError): + assert numerical.check_positive_float(none_val, none_allowed=False) assert numerical.check_positive_float(0) == 0 @@ -42,9 +42,9 @@ def test_check_positive_int(): with pytest.raises(ArgumentTypeError): assert numerical.check_positive_int(neg_val) - assert numerical.check_positive_int(None) is None - - with pytest.raises(ArgumentTypeError): - assert numerical.check_positive_int(None, none_allowed=False) + for none_val in [None, "None", "none"]: + assert numerical.check_positive_int(none_val) is None + with pytest.raises(ArgumentTypeError): + assert numerical.check_positive_int(none_val, none_allowed=False) assert numerical.check_positive_int(0) == 0 diff --git a/tests/tests/test_general/test_string.py b/tests/tests/test_general/test_string.py index eaced219..c6771c2e 100644 --- a/tests/tests/test_general/test_string.py +++ b/tests/tests/test_general/test_string.py @@ -1,3 +1,5 @@ +from argparse import ArgumentTypeError + import pytest from brainglobe_utils.general import string @@ -48,3 +50,13 @@ def test_get_string_lines(jabberwocky, jabberwocky_list): string.get_text_lines(jabberwocky, return_lines=8) == jabberwocky_list[8] ) + + +def test_check_str(): + assert "me" == string.check_str("me") + assert "12" == string.check_str("12") + + for none_val in [None, "None", "none"]: + assert string.check_str(none_val) is None + with pytest.raises(ArgumentTypeError): + assert string.check_str(none_val, none_allowed=False)