@@ -3,14 +3,18 @@
 
 import torch.optim as optim
 import torch.nn as nn
+
+from lpd.metrics.mock_metric import MockMetric
 from lpd.trainer import Trainer
-from lpd.callbacks import StatsPrint, SchedulerStep, LossOptimizerHandler, ModelCheckPoint, CallbackMonitor
+from lpd.callbacks import StatsPrint, SchedulerStep, LossOptimizerHandler, ModelCheckPoint, CallbackMonitor, \
+                          CallbackContext
 from lpd.extensions.custom_schedulers import KerasDecay
 from lpd.enums import Phase, State, MonitorType, StatsType, MonitorMode
 from lpd.metrics import BinaryAccuracyWithLogits, CategoricalAccuracyWithLogits
 import lpd.utils.torch_utils as tu
 import lpd.utils.general_utils as gu
 import examples.utils as eu
+from lpd.utils.threshold_checker import AbsoluteThresholdChecker
 
 
 class TestCallbacks(unittest.TestCase):
@@ -24,4 +28,63 @@ def test_stats_print_validations(self):
             StatsPrint(train_metrics_monitors=CallbackMonitor(monitor_type=MonitorType.METRIC,
                                                               stats_type=StatsType.TRAIN,
                                                               monitor_mode=MonitorMode.MAX,
-                                                              metric_name='Accuracy'))
+                                                              metric_name='Accuracy'))
+
+    def test_did_improve_gradually(self):
+        gu.seed_all(42)
+
+        device = tu.get_gpu_device_if_available()
+
+        model = eu.get_basic_model(10, 10, 10).to(device)
+
+        loss_func = nn.CrossEntropyLoss().to(device)
+
+        optimizer = optim.Adam(model.parameters(), lr=1e-4)
+
+        scheduler = KerasDecay(optimizer, 0.0001, last_step=-1)
+
+        metrics = MockMetric(0.0, 'mock_metric')
+
+        callbacks = [
+            LossOptimizerHandler()
+        ]
+
+        data_loader = eu.examples_data_generator(10, 10, 10, category_out=True)
+        data_loader_steps = 1
+
+        trainer = Trainer(model=model,
+                          device=device,
+                          loss_func=loss_func,
+                          optimizer=optimizer,
+                          scheduler=scheduler,
+                          metrics=metrics,
+                          train_data_loader=data_loader,
+                          val_data_loader=data_loader,
+                          train_steps=data_loader_steps,
+                          val_steps=data_loader_steps,
+                          callbacks=callbacks,
+                          name='Trainer-Test')
+
+        mock_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
+        threshold = 0.99
+        sp = StatsPrint(train_metrics_monitors=CallbackMonitor(monitor_type=MonitorType.METRIC,
+                                                               stats_type=StatsType.TRAIN,
+                                                               monitor_mode=MonitorMode.MAX,
+                                                               threshold_checker=AbsoluteThresholdChecker(MonitorMode.MAX,
+                                                                                                          threshold),
+                                                               metric_name='mock_metric'))
+
+        trainer.train(1)  # IMPROVE inf TO 0.0
+        res = sp.train_metrics_monitors[0].track(CallbackContext(trainer))
+        assert res.did_improve
+
+        for mock_value in mock_values:
+            metrics.set_mock_value(mock_value)
+            trainer.train(1)
+            res = sp.train_metrics_monitors[0].track(CallbackContext(trainer))
+            assert not res.did_improve
+
+        metrics.set_mock_value(1.0)  # IMPROVE 0.0 TO 1.0 (> 0.99)
+        trainer.train(1)
+        res = sp.train_metrics_monitors[0].track(CallbackContext(trainer))
+        assert res.did_improve
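
For context, the assertions in test_did_improve_gradually follow from the absolute-threshold rule the monitor is configured with: in MonitorMode.MAX with threshold=0.99, a tracked value only counts as an improvement when it beats the previous best by more than the threshold (the very first tracked value establishes the baseline and registers as an improvement). A minimal sketch of that decision rule, given here as an illustration of the behavior the test exercises rather than lpd's actual AbsoluteThresholdChecker implementation:

def did_improve_max(new_value: float, best_value: float, threshold: float = 0.99) -> bool:
    # MAX mode: only a gain strictly larger than the absolute threshold
    # counts as an improvement over the best value seen so far.
    return new_value - best_value > threshold

assert not did_improve_max(0.9, 0.0)  # 0.1..0.9 never beat 0.0 by more than 0.99
assert did_improve_max(1.0, 0.0)      # 1.0 - 0.0 = 1.0 > 0.99, so did_improve holds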