Skip to content

Commit 22cb9ef

Browse files
Copilotnjzjz
andcommitted
fix: revert implib file and clean up redundant test code
Co-authored-by: njzjz <[email protected]>
1 parent 7a2b41e commit 22cb9ef

File tree

2 files changed

+6
-18
lines changed

2 files changed

+6
-18
lines changed

source/tests/common/test_nan_detector.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@ def test_normal_values_pass(self):
2222

2323
# Should not raise any exception
2424
for i, loss_val in enumerate(normal_losses):
25-
try:
26-
check_total_loss_nan(100 + i, loss_val)
27-
except Exception as e:
28-
self.fail(f"Normal values should not raise exception: {e}")
25+
check_total_loss_nan(100 + i, loss_val)
2926

3027
def test_nan_detection_raises_exception(self):
3128
"""Test that NaN values trigger the proper exception."""

source/tests/common/test_nan_integration.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,27 +48,18 @@ def test_logging_on_nan_detection(self, mock_log):
4848

4949
def test_training_simulation_with_checkpoint_prevention(self):
5050
"""Simulate the training checkpoint scenario to ensure NaN prevents saving."""
51-
52-
def mock_save_checkpoint():
53-
"""Mock function that should not be called when NaN is detected."""
54-
raise AssertionError("Checkpoint should not be saved when NaN is detected!")
55-
5651
# Simulate the training flow: check total loss, then save checkpoint
5752
step_id = 1000
5853
total_loss = float("nan")
5954

60-
# This should raise LossNaNError before checkpoint saving
61-
with self.assertRaises(LossNaNError):
55+
# This should raise LossNaNError, preventing any subsequent checkpoint saving
56+
with self.assertRaises(LossNaNError) as context:
6257
check_total_loss_nan(step_id, total_loss)
63-
# This line should never be reached
64-
mock_save_checkpoint()
6558

6659
# Verify the error contains expected information
67-
try:
68-
check_total_loss_nan(step_id, total_loss)
69-
except LossNaNError as e:
70-
self.assertIn("Training stopped to prevent wasting time", str(e))
71-
self.assertIn("corrupted parameters", str(e))
60+
exception = context.exception
61+
self.assertIn("Training stopped to prevent wasting time", str(exception))
62+
self.assertIn("corrupted parameters", str(exception))
7263

7364
def test_realistic_training_scenario(self):
7465
"""Test a more realistic training scenario with decreasing then NaN loss."""

0 commit comments

Comments
 (0)