1
+ from lpd .callbacks .callback_monitor_result import CallbackMonitorResult
1
2
from lpd .enums import Phase , State , MonitorType , MonitorMode , StatsType
2
3
from lpd .callbacks .callback_context import CallbackContext
3
4
from typing import Union , List , Optional , Dict
6
7
7
8
from lpd .utils .threshold_checker import ThresholdChecker , AbsoluteThresholdChecker
8
9
9
-
10
10
class CallbackMonitor :
11
11
"""
12
12
Will check if the desired metric improved with support for patience
@@ -15,14 +15,14 @@ class CallbackMonitor:
15
15
(negative number will set to inf)
16
16
monitor_type - e.g lpd.enums.MonitorType.LOSS
17
17
stats_type - e.g lpd.enums.StatsType.VAL
18
- monitor_mode - e.g. lpd.enums.MonitorMode.MIN, min wothh check if the metric decreased, MAX will check for increase
18
+ monitor_mode - e.g. lpd.enums.MonitorMode.MIN, will check if the metric decreased, MonitorMode. MAX will check for increase
19
19
metric_name - in case of monitor_mode=lpd.enums.MonitorMode.METRIC, provide metric_name, otherwise, leave it None
20
20
threshold_checker - to check if the criteria was met, if None, AbsoluteThresholdChecker with threshold=0.0 will be used
21
21
"""
22
- def __init__ (self , monitor_type : MonitorType ,
23
- stats_type : StatsType ,
22
+ def __init__ (self , monitor_type : MonitorType ,
23
+ stats_type : StatsType ,
24
24
monitor_mode : MonitorMode ,
25
- patience : int = None ,
25
+ patience : int = None ,
26
26
metric_name : Optional [str ]= None ,
27
27
threshold_checker : Optional [ThresholdChecker ]= None ):
28
28
self .patience = inf if patience is None or patience < 0 else patience
@@ -32,9 +32,9 @@ def __init__(self, monitor_type: MonitorType,
32
32
self .monitor_mode = monitor_mode
33
33
self .threshold_checker = AbsoluteThresholdChecker (monitor_mode ) if threshold_checker is None else threshold_checker
34
34
self .metric_name = metric_name
35
- self .minimum = torch . tensor ( inf )
36
- self .maximum = torch . tensor ( - inf )
37
- self .previous = self . _get_best ()
35
+ self .minimum = None
36
+ self .maximum = None
37
+ self .previous = None
38
38
self .description = self ._get_description ()
39
39
self ._track_invoked = False
40
40
@@ -47,7 +47,7 @@ def _get_description(self):
47
47
def _get_best (self ):
48
48
return self .minimum if self .monitor_mode == MonitorMode .MIN else self .maximum
49
49
50
- def track (self , callback_context : CallbackContext ):
50
+ def track (self , callback_context : CallbackContext ) -> CallbackMonitorResult :
51
51
c = callback_context #READABILITY DOWN THE ROAD
52
52
53
53
# EXTRACT value_to_consider
@@ -70,8 +70,9 @@ def track(self, callback_context: CallbackContext):
70
70
value_to_consider = metrics_to_consider [self .metric_name ]
71
71
72
72
if not self ._track_invoked :
73
- self .minimum = - torch .log (torch .zeros_like (value_to_consider )) # [[inf,inf,inf,inf]]
74
- self .maximum = torch .log (torch .zeros_like (value_to_consider )) # [[-inf,-inf,-inf,-inf]]
73
+ self .minimum = - torch .log (torch .zeros_like (value_to_consider )) # [[inf,...,inf]]
74
+ self .maximum = torch .log (torch .zeros_like (value_to_consider )) # [[-inf,...,-inf]]
75
+ self .previous = self ._get_best ()
75
76
self ._track_invoked = True
76
77
77
78
@@ -80,13 +81,11 @@ def track(self, callback_context: CallbackContext):
80
81
change_from_previous = value_to_consider - self .previous
81
82
curr_best = self ._get_best ()
82
83
change_from_best = value_to_consider - curr_best
83
- curr_minimum = self .minimum
84
- curr_maximum = self .maximum
85
84
self .minimum = torch .min (self .minimum , value_to_consider )
86
85
self .maximum = torch .max (self .maximum , value_to_consider )
87
86
curr_previous = self .previous
88
87
self .previous = value_to_consider
89
- did_improve = False
88
+ did_improve = False # UNLESS SAID OTHERWISE
90
89
new_best = self ._get_best ()
91
90
name = self .metric_name if self .metric_name else 'loss'
92
91
@@ -109,32 +108,3 @@ def track(self, callback_context: CallbackContext):
109
108
patience_left = self .patience_countdown ,
110
109
description = self .description ,
111
110
name = name )
112
-
113
-
114
- class CallbackMonitorResult ():
115
- def __init__ (self , did_improve : bool ,
116
- new_value : float ,
117
- prev_value : float ,
118
- new_best : float ,
119
- prev_best : float ,
120
- change_from_previous : float ,
121
- change_from_best : float ,
122
- patience_left : int ,
123
- description : str ,
124
- name : str ):
125
- self .name = name
126
- self .did_improve = did_improve
127
- self .new_value = new_value
128
- self .prev_value = prev_value
129
- self .new_best = new_best
130
- self .prev_best = prev_best
131
- self .change_from_previous = change_from_previous
132
- self .change_from_best = change_from_best
133
- self .patience_left = patience_left
134
- self .description = description
135
-
136
- def has_improved (self ):
137
- return self .did_improve
138
-
139
- def has_patience (self ):
140
- return self .patience_left > 0
0 commit comments