Skip to content

Commit c44c8bf

Browse files
NaorHabaNaor HabaRoyToluna
authored
Enhancement#5 add threshold checker (#6)
* added threshold checker * updated Pipfile * changed threshold to greater (instead of greater-equal) * added tests for threshold_checker * updated pipfile * updated threshold_checker to work with __call__ * * added monitor_mode to base class * moved input validation to base class * name fix for utils test class * attached tests to main.py Co-authored-by: Naor Haba <[email protected]> Co-authored-by: Roy Sadaka <[email protected]>
1 parent 8476f2a commit c44c8bf

File tree

7 files changed

+140
-15
lines changed

7 files changed

+140
-15
lines changed

Pipfile

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ tensorboard = "==2.3.0"
88
tqdm = "==4.51.0"
99

1010
[packages]
11-
numpy = "==1.19.2"
11+
numpy = "*"
1212
torch = "*"
1313
torchvision = "*"
1414
protobuf = "==3.20.*"

lpd/callbacks/callback_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class CallbackBase():
5050
State.EXTERNAL
5151
Phase.PREDICT_END
5252
53-
Agrs:
53+
Args:
5454
apply_on_phase - (lpd.enums.Phase) the phase to invoke this callback
5555
apply_on_states - (lpd.enums.State) state or list of states to invoke this parameter (under the relevant phase), None will invoke it on all states
5656
round_values_on_print_to - optional, it will round the numerical values in the prints

lpd/callbacks/callback_monitor.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,34 @@
44
from math import inf
55
import torch
66

7-
class CallbackMonitor():
7+
from lpd.utils.threshold_checker import ThresholdChecker, AbsoluteThresholdChecker
8+
9+
10+
class CallbackMonitor:
811
"""
912
Will check if the desired metric improved with support for patience
10-
Agrs:
13+
Args:
1114
patience - int or None (will be set to inf) track how many epochs/iterations without improvements in monitoring
1215
(negative number will set to inf)
1316
monitor_type - e.g lpd.enums.MonitorType.LOSS
1417
stats_type - e.g lpd.enums.StatsType.VAL
1518
monitor_mode - e.g. lpd.enums.MonitorMode.MIN, min wothh check if the metric decreased, MAX will check for increase
1619
metric_name - in case of monitor_mode=lpd.enums.MonitorMode.METRIC, provide metric_name, otherwise, leave it None
1720
"""
18-
def __init__(self, monitor_type: MonitorType, stats_type: StatsType, monitor_mode: MonitorMode, patience: int=None, metric_name: Optional[str]=None):
21+
def __init__(self, monitor_type: MonitorType, stats_type: StatsType, monitor_mode: MonitorMode,
22+
threshold_checker: Optional[ThresholdChecker] = None, patience: int=None, metric_name: Optional[str]=None):
1923
self.patience = inf if patience is None or patience < 0 else patience
2024
self.patience_countdown = self.patience
2125
self.monitor_type = monitor_type
2226
self.stats_type = stats_type
2327
self.monitor_mode = monitor_mode
28+
self.threshold_checker = AbsoluteThresholdChecker(monitor_mode) if threshold_checker is None else threshold_checker
2429
self.metric_name = metric_name
2530
self.minimum = torch.tensor(inf)
2631
self.maximum = torch.tensor(-inf)
2732
self.previous = self._get_best()
2833
self.description = self._get_description()
29-
self._track_invoked = False
34+
self._track_invoked = False
3035

3136
def _get_description(self):
3237
desc = f'{self.monitor_mode}_{self.stats_type}_{self.monitor_type}'
@@ -82,29 +87,28 @@ def track(self, callback_context: CallbackContext):
8287

8388
if len(value_to_consider.shape) == 0 or \
8489
(len(value_to_consider.shape) == 1 and value_to_consider.shape[0] == 1):
85-
if self.monitor_mode == MonitorMode.MIN and value_to_consider < curr_minimum or \
86-
self.monitor_mode == MonitorMode.MAX and value_to_consider > curr_maximum:
90+
if self.threshold_checker(new_value=value_to_consider, old_value=curr_best):
8791
did_improve = True
8892
self.patience_countdown = self.patience
8993
else:
9094
if self.patience != inf:
9195
raise ValueError("[CallbackMonitor] - can't monitor patience for metric that has multiple values")
92-
93-
return CallbackMonitorResult(did_improve=did_improve,
94-
new_value=value_to_consider,
96+
97+
return CallbackMonitorResult(did_improve=did_improve,
98+
new_value=value_to_consider,
9599
prev_value=curr_previous,
96100
new_best=new_best,
97101
prev_best=curr_best,
98102
change_from_previous=change_from_previous,
99103
change_from_best=change_from_best,
100-
patience_left=self.patience_countdown,
104+
patience_left=self.patience_countdown,
101105
description=self.description,
102106
name = name)
103107

104108

105109
class CallbackMonitorResult():
106-
def __init__(self, did_improve: bool,
107-
new_value: float,
110+
def __init__(self, did_improve: bool,
111+
new_value: float,
108112
prev_value: float,
109113
new_best: float,
110114
prev_best: float,

lpd/callbacks/scheduler_step.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
class SchedulerStep(CallbackBase):
99
"""This callback will invoke a "step()" on the scheduler.
1010
11-
Agrs:
11+
Args:
1212
apply_on_phase - see in CallbackBase
1313
apply_on_states - see in CallbackBase
1414
scheduler_parameters_func - Since some schedulers takes parameters in step(param1, param2...)

lpd/utils/threshold_checker.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Union
3+
from torch import Tensor
4+
from lpd.enums import MonitorMode
5+
6+
7+
class ThresholdChecker(ABC):
8+
"""
9+
Check if the current value is better than the previous best value according to different threshold criteria
10+
This is an abstract class meant to be inherited by different threshold checkers
11+
Can also be inherited by the user to create custom threshold checkers
12+
"""
13+
def __init__(self, monitor_mode: MonitorMode, threshold: float):
14+
self.monitor_mode = monitor_mode
15+
self.threshold = threshold
16+
17+
def validate_input(self):
18+
if self.threshold < 0:
19+
raise ValueError(f"Threshold must be non-negative, but got {self.threshold}")
20+
21+
@abstractmethod
22+
def __call__(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
23+
pass
24+
25+
26+
class AbsoluteThresholdChecker(ThresholdChecker):
27+
"""
28+
A threshold checker that checks if the difference between the current value and the previous best value
29+
is greater than or equal to a given threshold
30+
31+
Args:
32+
monitor_mode: MIN or MAX
33+
threshold - the threshold to check (must be non-negative)
34+
"""
35+
def __init__(self, monitor_mode: MonitorMode, threshold: float = 0.0):
36+
super(AbsoluteThresholdChecker, self).__init__(monitor_mode, threshold)
37+
38+
def _is_new_value_lower(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
39+
return old_value - new_value > self.threshold
40+
41+
def _is_new_value_higher(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
42+
return new_value - old_value > self.threshold
43+
44+
def __call__(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
45+
if self.monitor_mode == MonitorMode.MIN:
46+
return self._is_new_value_lower(new_value, old_value)
47+
if self.monitor_mode == MonitorMode.MAX:
48+
return self._is_new_value_higher(new_value, old_value)
49+
50+
51+
class RelativeThresholdChecker(ThresholdChecker):
52+
"""
53+
A threshold checker that checks if the relative difference between the current value and the previous best value
54+
is greater than or equal to a given threshold
55+
56+
Args:
57+
threshold - the threshold to check (must be non-negative)
58+
"""
59+
def __init__(self, monitor_mode: MonitorMode, threshold: float = 0.0):
60+
super(RelativeThresholdChecker, self).__init__(monitor_mode, threshold)
61+
62+
def _is_new_value_lower(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
63+
return (old_value - new_value) / old_value > self.threshold
64+
65+
def _is_new_value_higher(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
66+
return (new_value - old_value) / old_value > self.threshold
67+
68+
def __call__(self, new_value: Union[float, Tensor], old_value: Union[float, Tensor]) -> bool:
69+
if self.monitor_mode == MonitorMode.MIN:
70+
return self._is_new_value_lower(new_value, old_value)
71+
if self.monitor_mode == MonitorMode.MAX:
72+
return self._is_new_value_higher(new_value, old_value)

main.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from tests.test_trainer import TestTrainer
33
from tests.test_predictor import TestPredictor
44
from tests.test_callbacks import TestCallbacks
5+
from tests.test_utils import TestUtils
56
import unittest
67

78
import examples.multiple_inputs.train as multiple_inputs_example

tests/test_utils.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import unittest
2+
3+
from lpd.enums import MonitorMode
4+
5+
6+
class TestUtils(unittest.TestCase):
7+
8+
def test_absolute_threshold_checker__true(self):
9+
from lpd.utils.threshold_checker import AbsoluteThresholdChecker
10+
for (threshold, higher_value, lower_value) in [(0.0, 0.9, 0.899), (0.1, 0.9, 0.799)]:
11+
min_checker = AbsoluteThresholdChecker(MonitorMode.MIN, threshold)
12+
with self.subTest():
13+
self.assertTrue(min_checker(new_value=lower_value, old_value=higher_value))
14+
15+
max_checker = AbsoluteThresholdChecker(MonitorMode.MAX, threshold)
16+
with self.subTest():
17+
self.assertTrue(max_checker(new_value=higher_value, old_value=lower_value))
18+
19+
def test_absolute_threshold_checker__false(self):
20+
from lpd.utils.threshold_checker import AbsoluteThresholdChecker
21+
for (threshold, higher_value, lower_value) in [(0.0, 0.9, 0.9), (0.1, 0.9, 0.81)]:
22+
min_checker = AbsoluteThresholdChecker(MonitorMode.MIN, threshold)
23+
with self.subTest():
24+
self.assertFalse(min_checker(new_value=lower_value, old_value=higher_value))
25+
26+
max_checker = AbsoluteThresholdChecker(MonitorMode.MAX, threshold)
27+
with self.subTest():
28+
self.assertFalse(max_checker(new_value=higher_value, old_value=lower_value))
29+
30+
def test_relative_threshold_checker__true(self):
31+
from lpd.utils.threshold_checker import RelativeThresholdChecker
32+
for (threshold, higher_value, lower_value) in [(0.0, 0.9, 0.899), (0.1, 120.1, 100.0)]:
33+
min_checker = RelativeThresholdChecker(MonitorMode.MIN, threshold)
34+
with self.subTest():
35+
self.assertTrue(min_checker(new_value=lower_value, old_value=higher_value))
36+
max_checker = RelativeThresholdChecker(MonitorMode.MAX, threshold)
37+
with self.subTest():
38+
self.assertTrue(max_checker(new_value=higher_value, old_value=lower_value))
39+
40+
def test_relative_threshold_checker__false(self):
41+
from lpd.utils.threshold_checker import RelativeThresholdChecker
42+
for (threshold, higher_value, lower_value) in [(0.0, 0.9, 0.9), (0.1, 109.99, 100.0)]:
43+
min_checker = RelativeThresholdChecker(MonitorMode.MIN, threshold)
44+
with self.subTest():
45+
self.assertFalse(min_checker(new_value=lower_value, old_value=higher_value))
46+
max_checker = RelativeThresholdChecker(MonitorMode.MAX, threshold)
47+
with self.subTest():
48+
self.assertFalse(max_checker(new_value=higher_value, old_value=lower_value))

0 commit comments

Comments
 (0)