@@ -48,27 +48,18 @@ def test_logging_on_nan_detection(self, mock_log):
48
48
49
49
def test_training_simulation_with_checkpoint_prevention (self ):
50
50
"""Simulate the training checkpoint scenario to ensure NaN prevents saving."""
51
-
52
- def mock_save_checkpoint ():
53
- """Mock function that should not be called when NaN is detected."""
54
- raise AssertionError ("Checkpoint should not be saved when NaN is detected!" )
55
-
56
51
# Simulate the training flow: check total loss, then save checkpoint
57
52
step_id = 1000
58
53
total_loss = float ("nan" )
59
54
60
- # This should raise LossNaNError before checkpoint saving
61
- with self .assertRaises (LossNaNError ):
55
+ # This should raise LossNaNError, preventing any subsequent checkpoint saving
56
+ with self .assertRaises (LossNaNError ) as context :
62
57
check_total_loss_nan (step_id , total_loss )
63
- # This line should never be reached
64
- mock_save_checkpoint ()
65
58
66
59
# Verify the error contains expected information
67
- try :
68
- check_total_loss_nan (step_id , total_loss )
69
- except LossNaNError as e :
70
- self .assertIn ("Training stopped to prevent wasting time" , str (e ))
71
- self .assertIn ("corrupted parameters" , str (e ))
60
+ exception = context .exception
61
+ self .assertIn ("Training stopped to prevent wasting time" , str (exception ))
62
+ self .assertIn ("corrupted parameters" , str (exception ))
72
63
73
64
def test_realistic_training_scenario (self ):
74
65
"""Test a more realistic training scenario with decreasing then NaN loss."""
0 commit comments