Skip to content

Commit 9eb1bea

Browse files
Copilot and njzjz committed
feat(training): add comprehensive NaN detection with tests and validation
Co-authored-by: njzjz <[email protected]>
1 parent ef431a1 commit 9eb1bea

File tree

7 files changed

+1046
-508
lines changed

7 files changed

+1046
-508
lines changed

deepmd/pd/train/training.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@
7575
from deepmd.utils.data import (
7676
DataRequirementItem,
7777
)
78+
from deepmd.utils.nan_detector import (
79+
check_loss_nan,
80+
)
7881
from deepmd.utils.path import (
7982
DPH5Path,
8083
)
@@ -951,6 +954,20 @@ def log_loss_valid(_task_key="Default"):
951954
fout, display_step_id, cur_lr, train_results, valid_results
952955
)
953956

957+
# Check for NaN in loss values before saving checkpoint
958+
# Loss values are already on CPU at this point for display/logging
959+
if self.rank == 0:
960+
if not self.multi_task:
961+
check_loss_nan(display_step_id, train_results)
962+
if valid_results:
963+
check_loss_nan(display_step_id, valid_results)
964+
else:
965+
for task_key in train_results:
966+
if train_results[task_key]:
967+
check_loss_nan(display_step_id, train_results[task_key])
968+
if valid_results[task_key]:
969+
check_loss_nan(display_step_id, valid_results[task_key])
970+
954971
if (
955972
((_step_id + 1) % self.save_freq == 0 and _step_id != self.start_step)
956973
or (_step_id + 1) == self.num_steps

deepmd/pt/train/training.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@
7575
from deepmd.utils.data import (
7676
DataRequirementItem,
7777
)
78+
from deepmd.utils.nan_detector import (
79+
check_loss_nan,
80+
)
7881

7982
if torch.__version__.startswith("2"):
8083
import torch._dynamo
@@ -1070,6 +1073,20 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
10701073
fout, display_step_id, cur_lr, train_results, valid_results
10711074
)
10721075

1076+
# Check for NaN in loss values before saving checkpoint
1077+
# Loss values are already on CPU at this point for display/logging
1078+
if self.rank == 0:
1079+
if not self.multi_task:
1080+
check_loss_nan(display_step_id, train_results)
1081+
if valid_results:
1082+
check_loss_nan(display_step_id, valid_results)
1083+
else:
1084+
for task_key in train_results:
1085+
if train_results[task_key]:
1086+
check_loss_nan(display_step_id, train_results[task_key])
1087+
if valid_results[task_key]:
1088+
check_loss_nan(display_step_id, valid_results[task_key])
1089+
10731090
if (
10741091
(
10751092
(display_step_id) % self.save_freq == 0

deepmd/tf/train/trainer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@
6060
from deepmd.utils.data import (
6161
DataRequirementItem,
6262
)
63+
from deepmd.utils.nan_detector import (
64+
check_loss_nan,
65+
)
6366

6467
log = logging.getLogger(__name__)
6568

@@ -684,6 +687,13 @@ def valid_on_the_fly(
684687

685688
cur_batch = self.cur_batch
686689
current_lr = run_sess(self.sess, self.learning_rate)
690+
691+
# Check for NaN in loss values before writing to file and saving checkpoint
692+
# Loss values are already on CPU at this point
693+
check_loss_nan(cur_batch, train_results)
694+
if valid_results is not None:
695+
check_loss_nan(cur_batch, valid_results)
696+
687697
if print_header:
688698
self.print_header(fp, train_results, valid_results)
689699
self.print_on_training(

deepmd/utils/nan_detector.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""Utilities for detecting NaN values in loss during training."""
3+
4+
import logging
5+
import math
6+
from typing import (
7+
Any,
8+
)
9+
10+
import numpy as np
11+
12+
log = logging.getLogger(__name__)
13+
14+
15+
class LossNaNError(Exception):
    """Exception raised when NaN is detected in loss during training.

    Attributes
    ----------
    step : int
        The training step passed to the constructor.
    loss_dict : dict[str, Any]
        The loss dictionary passed to the constructor (stored by reference).
    """

    def __init__(self, step: int, loss_dict: dict[str, Any]) -> None:
        """Initialize the exception.

        Parameters
        ----------
        step : int
            The training step where NaN was detected
        loss_dict : dict[str, Any]
            Dictionary containing the loss values where NaN was found
        """
        self.step = step
        self.loss_dict = loss_dict
        # The message is built after the attributes are set because
        # _format_message reads both of them.
        super().__init__(self._format_message())

    def _format_message(self) -> str:
        """Format the error message.

        Returns
        -------
        str
            A message naming the step and listing every ``key=value`` pair
            of ``loss_dict`` whose value is NaN.
        """
        nan_losses = [
            f"{key}={value}"
            for key, value in self.loss_dict.items()
            if self._is_nan(value)
        ]
        return (
            f"NaN detected in loss at training step {self.step}. "
            f"Training stopped to prevent wasting time with corrupted parameters. "
            f"NaN values found in: {', '.join(nan_losses)}. "
            f"This typically indicates unstable training conditions such as "
            f"learning rate too high, poor data quality, or numerical instability."
        )

    @staticmethod
    def _is_nan(value: Any) -> bool:
        """Check if a value is NaN.

        Parameters
        ----------
        value : Any
            A Python scalar, a NumPy array, an object exposing ``item()``
            (scalar-like tensor), or anything convertible with ``float()``.

        Returns
        -------
        bool
            True if the value is NaN (for NumPy arrays: if any element is
            NaN). False for ``None`` and for values that cannot be
            interpreted as a float.
        """
        if value is None:
            return False
        try:
            # NumPy arrays must be tested before the generic ``item`` branch:
            # every ndarray has ``.item``, and calling it on a multi-element
            # array raises ValueError, which would be swallowed below and
            # hide the NaN.
            if isinstance(value, np.ndarray):
                return bool(np.isnan(value).any())
            if hasattr(value, "item"):
                # Scalar-like tensor or NumPy scalar
                return math.isnan(value.item())
            if isinstance(value, (int, float)):
                # Python scalar
                return math.isnan(value)
            # Try to convert to float and check
            return math.isnan(float(value))
        except (TypeError, ValueError, OverflowError):
            # Not interpretable as a float (or an int too large to be one),
            # so it cannot be NaN.
            return False
70+
71+
72+
def check_loss_nan(step: int, loss_dict: dict[str, Any]) -> None:
    """Raise ``LossNaNError`` when any entry of ``loss_dict`` is NaN.

    Every entry is inspected, and each NaN entry is logged at error level
    as it is found. The exception is raised only after the full dictionary
    has been scanned.

    Parameters
    ----------
    step : int
        Current training step
    loss_dict : dict[str, Any]
        Dictionary containing loss values to check for NaN

    Raises
    ------
    LossNaNError
        If any loss value contains NaN
    """
    offending_keys = []
    for name, loss in loss_dict.items():
        if not LossNaNError._is_nan(loss):
            continue
        offending_keys.append(name)
        log.error(f"NaN detected in {name} at step {step}: {loss}")

    if offending_keys:
        # The whole dictionary is passed on; the exception picks out the
        # NaN entries itself when building its message.
        raise LossNaNError(step, loss_dict)
98+
99+
100+
def check_single_loss_nan(step: int, loss_name: str, loss_value: Any) -> None:
    """Raise ``LossNaNError`` when one named loss value is NaN.

    Parameters
    ----------
    step : int
        Current training step
    loss_name : str
        Name/identifier of the loss
    loss_value : Any
        Loss value to check for NaN

    Raises
    ------
    LossNaNError
        If the loss value contains NaN
    """
    if not LossNaNError._is_nan(loss_value):
        return

    log.error(f"NaN detected in {loss_name} at step {step}: {loss_value}")
    # Wrap the single value in a one-entry dict, the shape LossNaNError expects.
    single_entry = {loss_name: loss_value}
    raise LossNaNError(step, single_entry)

0 commit comments

Comments
 (0)