
Commit 61b0f29

Authored Mar 16, 2023
Merge pull request #9 from RoySadaka/threshold_checker_improvement
Threshold checker improvement
2 parents 566bf99 + cd54e1d commit 61b0f29

File tree

4 files changed (+91 -16 lines)

 

lpd/callbacks/callback_monitor.py (+8 -8)

@@ -32,8 +32,7 @@ def __init__(self, monitor_type: MonitorType,
         self.monitor_mode = monitor_mode
         self.threshold_checker = AbsoluteThresholdChecker(monitor_mode) if threshold_checker is None else threshold_checker
         self.metric_name = metric_name
-        self.minimum = None
-        self.maximum = None
+        self.best = None
         self.previous = None
         self.description = self._get_description()
         self._track_invoked = False
@@ -45,7 +44,7 @@ def _get_description(self):
         return desc
 
     def _get_best(self):
-        return self.minimum if self.monitor_mode == MonitorMode.MIN else self.maximum
+        return self.best
 
     def track(self, callback_context: CallbackContext) -> CallbackMonitorResult:
         c = callback_context #READABILITY DOWN THE ROAD
@@ -70,8 +69,10 @@ def track(self, callback_context: CallbackContext) -> CallbackMonitorResult:
         value_to_consider = metrics_to_consider[self.metric_name]
 
         if not self._track_invoked:
-            self.minimum = -torch.log(torch.zeros_like(value_to_consider)) # [[inf,...,inf]]
-            self.maximum = torch.log(torch.zeros_like(value_to_consider)) # [[-inf,...,-inf]]
+            if self.monitor_mode == MonitorMode.MIN:
+                self.best = -torch.log(torch.zeros_like(value_to_consider)) # [[inf,...,inf]]
+            elif self.monitor_mode == MonitorMode.MAX:
+                self.best = torch.log(torch.zeros_like(value_to_consider)) # [[-inf,...,-inf]]
             self.previous = self._get_best()
             self._track_invoked = True
 
@@ -81,19 +82,18 @@
         change_from_previous = value_to_consider - self.previous
         curr_best = self._get_best()
         change_from_best = value_to_consider - curr_best
-        self.minimum = torch.min(self.minimum, value_to_consider)
-        self.maximum = torch.max(self.maximum, value_to_consider)
         curr_previous = self.previous
         self.previous = value_to_consider
         did_improve = False # UNLESS SAID OTHERWISE
-        new_best = self._get_best()
+        new_best = curr_best # UNLESS SAID OTHERWISE
         name = self.metric_name if self.metric_name else 'loss'
 
         if len(value_to_consider.shape) == 0 or \
            (len(value_to_consider.shape) == 1 and value_to_consider.shape[0] == 1):
             if self.threshold_checker(new_value=value_to_consider, old_value=curr_best):
                 did_improve = True
                 self.patience_countdown = self.patience
+                self.best = new_best = value_to_consider
         else:
             if self.patience != inf:
                 raise ValueError("[CallbackMonitor] - can't monitor patience for metric that has multiple values")
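
The net effect of this file's change: the monitor keeps a single mode-aware self.best instead of tracking both extremes, and best now advances only when the threshold checker confirms an improvement. Previously the unconditional torch.min/torch.max updates moved the comparison baseline even for values that never cleared the threshold. Below is a runnable sketch of the new decision flow, assuming an absolute-threshold rule; AbsoluteThresholdChecker's own body is not part of this commit, so the checker class here is a hypothetical stand-in that borrows only the constructor and keyword call signature used above.

import torch
from enum import Enum

class MonitorMode(Enum):
    MIN = 'min'
    MAX = 'max'

class AbsoluteThresholdSketch:
    # Hypothetical stand-in for lpd's AbsoluteThresholdChecker:
    # an improvement counts only if it beats the best by more than `threshold`.
    def __init__(self, monitor_mode: MonitorMode, threshold: float = 0.0):
        self.monitor_mode = monitor_mode
        self.threshold = threshold

    def __call__(self, new_value, old_value) -> bool:
        if self.monitor_mode == MonitorMode.MAX:
            return bool(new_value - old_value > self.threshold)
        return bool(old_value - new_value > self.threshold)

# Mirrors the patched track() flow: best advances only on a confirmed improvement.
checker = AbsoluteThresholdSketch(MonitorMode.MAX, threshold=0.99)
best = torch.tensor(float('-inf'))  # MAX mode starts at -inf, as in the diff
for value in [0.0, 0.5, 0.9, 1.0]:
    candidate = torch.tensor(value)
    if checker(new_value=candidate, old_value=best):
        best = candidate  # reached for 0.0 and 1.0, skipped for 0.5 and 0.9
print(best)  # tensor(1.)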

lpd/metrics/mock_metric.py (+16, new file)

@@ -0,0 +1,16 @@
+import torch
+
+from lpd.enums import MetricMethod
+from lpd.metrics import MetricBase
+
+
+class MockMetric(MetricBase):
+    def __init__(self, mock_value, name: str):
+        super(MockMetric, self).__init__(name=name, metric_method=MetricMethod.LAST)
+        self.mock_value = mock_value
+
+    def __call__(self, y_pred: torch.Tensor, y_true: torch.Tensor):
+        return torch.FloatTensor([self.mock_value])
+
+    def set_mock_value(self, mock_value):
+        self.mock_value = mock_value
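
MockMetric arrives in full above, so its contract is simple to state: it ignores predictions and targets and reports whatever value it was last given, which makes monitor behavior fully scripted in tests. A quick illustration (the dummy tensors are arbitrary):

import torch
from lpd.metrics.mock_metric import MockMetric

metric = MockMetric(0.0, 'mock_metric')
dummy = torch.zeros(4, 10)
print(metric(dummy, dummy))  # tensor([0.]) regardless of inputs

metric.set_mock_value(0.5)
print(metric(dummy, dummy))  # tensor([0.5000])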

lpd/trainer_stats.py (+2 -6)

@@ -28,15 +28,11 @@ def add_value(self, value, count):
         elif self.metric_method == MetricMethod.SUM:
             self.sum += value
             self.count += count
-
-        elif self.metric_method == MetricMethod.LAST:
-            self.sum = value
-            self.count = 1
 
         self.last = value
 
     def get_value(self):
-        if self.count == 0:
+        if self.last is None:
             return torch.tensor(0.0)
 
         if self.metric_method == MetricMethod.MEAN:
@@ -46,7 +42,7 @@ def get_value(self):
             return self.sum
 
         elif self.metric_method == MetricMethod.LAST:
-            return self.sum
+            return self.last
 
 class StatsResult():
     def __init__(self, trainer_name, stats):
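
The LAST branch previously stored its value in self.sum with count forced to 1, so get_value's emptiness check (count == 0) happened to work. Once the branch is gone, self.last, which add_value already records unconditionally, becomes the source of truth, and the emptiness check becomes last is None. A condensed sketch of the resulting accumulator, assuming the class shape implied by the hunk context; the MEAN arithmetic and the initial field values are guesses, not shown in this diff:

import torch
from enum import Enum

class MetricMethod(Enum):
    MEAN = 0
    SUM = 1
    LAST = 2

class StatsSketch:
    # Hypothetical condensed version of the accumulator this hunk patches.
    def __init__(self, metric_method: MetricMethod):
        self.metric_method = metric_method
        self.sum = 0.0
        self.count = 0
        self.last = None

    def add_value(self, value, count=1):
        if self.metric_method == MetricMethod.MEAN:
            self.sum += value * count   # assumed; MEAN branch body not in the diff
            self.count += count
        elif self.metric_method == MetricMethod.SUM:
            self.sum += value
            self.count += count
        # LAST no longer needs its own branch: the unconditional write below
        # (kept from the original code) is all it requires.
        self.last = value

    def get_value(self):
        if self.last is None:           # was: if self.count == 0
            return torch.tensor(0.0)
        if self.metric_method == MetricMethod.MEAN:
            return self.sum / self.count
        if self.metric_method == MetricMethod.SUM:
            return self.sum
        return self.last                # was: return self.sum

stats = StatsSketch(MetricMethod.LAST)
print(stats.get_value())                # tensor(0.) before any value arrives
stats.add_value(torch.tensor(0.7))
print(stats.get_value())                # tensor(0.7000)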

tests/test_callbacks.py (+65 -2)

@@ -3,14 +3,18 @@
 
 import torch.optim as optim
 import torch.nn as nn
+
+from lpd.metrics.mock_metric import MockMetric
 from lpd.trainer import Trainer
-from lpd.callbacks import StatsPrint, SchedulerStep, LossOptimizerHandler, ModelCheckPoint, CallbackMonitor
+from lpd.callbacks import StatsPrint, SchedulerStep, LossOptimizerHandler, ModelCheckPoint, CallbackMonitor, \
+    CallbackContext
 from lpd.extensions.custom_schedulers import KerasDecay
 from lpd.enums import Phase, State, MonitorType, StatsType, MonitorMode
 from lpd.metrics import BinaryAccuracyWithLogits, CategoricalAccuracyWithLogits
 import lpd.utils.torch_utils as tu
 import lpd.utils.general_utils as gu
 import examples.utils as eu
+from lpd.utils.threshold_checker import AbsoluteThresholdChecker
 
 
 class TestCallbacks(unittest.TestCase):
@@ -24,4 +28,63 @@ def test_stats_print_validations(self):
         StatsPrint(train_metrics_monitors=CallbackMonitor(monitor_type=MonitorType.METRIC,
                                                           stats_type=StatsType.TRAIN,
                                                           monitor_mode=MonitorMode.MAX,
-                                                          metric_name='Accuracy'))
+                                                          metric_name='Accuracy'))
+
+    def test_did_improve_gradually(self):
+        gu.seed_all(42)
+
+        device = tu.get_gpu_device_if_available()
+
+        model = eu.get_basic_model(10, 10, 10).to(device)
+
+        loss_func = nn.CrossEntropyLoss().to(device)
+
+        optimizer = optim.Adam(model.parameters(), lr=1e-4)
+
+        scheduler = KerasDecay(optimizer, 0.0001, last_step=-1)
+
+        metrics = MockMetric(0.0, 'mock_metric')
+
+        callbacks = [
+            LossOptimizerHandler()
+        ]
+
+        data_loader = eu.examples_data_generator(10, 10, 10, category_out=True)
+        data_loader_steps = 1
+
+        trainer = Trainer(model=model,
+                          device=device,
+                          loss_func=loss_func,
+                          optimizer=optimizer,
+                          scheduler=scheduler,
+                          metrics=metrics,
+                          train_data_loader=data_loader,
+                          val_data_loader=data_loader,
+                          train_steps=data_loader_steps,
+                          val_steps=data_loader_steps,
+                          callbacks=callbacks,
+                          name='Trainer-Test')
+
+        mock_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
+        threshold = 0.99
+        sp = StatsPrint(train_metrics_monitors=CallbackMonitor(monitor_type=MonitorType.METRIC,
+                                                               stats_type=StatsType.TRAIN,
+                                                               monitor_mode=MonitorMode.MAX,
+                                                               threshold_checker=AbsoluteThresholdChecker(MonitorMode.MAX,
+                                                                                                          threshold),
+                                                               metric_name='mock_metric'))
+
+        trainer.train(1) # IMPROVE -inf TO 0.0
+        res = sp.train_metrics_monitors[0].track(CallbackContext(trainer))
+        assert res.did_improve
+
+        for mock_value in mock_values:
+            metrics.set_mock_value(mock_value)
+            trainer.train(1)
+            res = sp.train_metrics_monitors[0].track(CallbackContext(trainer))
+            assert not res.did_improve
+
+        metrics.set_mock_value(1.0) # IMPROVE 0.0 TO 1.0 (> 0.99)
+        trainer.train(1)
+        res = sp.train_metrics_monitors[0].track(CallbackContext(trainer))
+        assert res.did_improve
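
Taken together, the new test pins down the semantics this PR introduces: with AbsoluteThresholdChecker(MonitorMode.MAX, 0.99), the first tracked value (0.0) improves on the initial best of -inf; each of 0.1 through 0.9 beats the standing best of 0.0 by at most 0.9, under the threshold, so did_improve stays False and best stays at 0.0; only 1.0 clears it (1.0 - 0.0 > 0.99). Under the pre-patch code, the unconditional torch.max update would have dragged the baseline up to 0.9 along the way, and the final assertion would fail (1.0 - 0.9 = 0.1 < 0.99).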
