|
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utilities for detecting NaN values in loss during training."""

import logging
import math
from typing import (
    Any,
)

import numpy as np

log = logging.getLogger(__name__)


class LossNaNError(Exception):
    """Exception raised when NaN is detected in loss during training."""

    def __init__(self, step: int, loss_dict: dict[str, Any]) -> None:
        """Initialize the exception.

        Parameters
        ----------
        step : int
            The training step where NaN was detected
        loss_dict : dict[str, Any]
            Dictionary containing the loss values where NaN was found
        """
        self.step = step
        self.loss_dict = loss_dict
        super().__init__(self._format_message())

    def _format_message(self) -> str:
        """Format the error message."""
        nan_losses = []
        for key, value in self.loss_dict.items():
            if self._is_nan(value):
                nan_losses.append(f"{key}={value}")

        message = (
            f"NaN detected in loss at training step {self.step}. "
            "Training stopped to avoid continuing with corrupted parameters. "
            f"NaN values found in: {', '.join(nan_losses)}. "
            "This typically indicates unstable training conditions such as "
            "a learning rate that is too high, poor data quality, or numerical instability."
        )
        return message

    @staticmethod
    def _is_nan(value: Any) -> bool:
        """Check if a value is NaN."""
        if value is None:
            return False
        try:
            # Handle NumPy arrays, tensor types, and Python scalars
            if isinstance(value, np.ndarray):
                # NumPy array; checked before the ``item`` branch because
                # multi-element arrays also expose ``item`` but raise when it is called
                return bool(np.isnan(value).any())
            elif hasattr(value, "item"):
                # Scalar PyTorch/TensorFlow/PaddlePaddle tensor
                return math.isnan(value.item())
            elif isinstance(value, (int, float)):
                # Python scalar
                return math.isnan(value)
            else:
                # Try to convert to float and check
                return math.isnan(float(value))
        except (TypeError, ValueError):
            # If we can't convert to float, assume it's not NaN
            return False


def check_loss_nan(step: int, loss_dict: dict[str, Any]) -> None:
    """Check if any loss values contain NaN and raise an exception if found.

    This function is designed to be called during training after loss values
    are computed and available on CPU, typically during the logging/display phase.

    Parameters
    ----------
    step : int
        Current training step
    loss_dict : dict[str, Any]
        Dictionary containing loss values to check for NaN

    Raises
    ------
    LossNaNError
        If any loss value contains NaN
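
    Examples
    --------
    A minimal illustration; the loss keys below are arbitrary placeholders
    rather than a required naming convention:

    >>> check_loss_nan(100, {"rmse_e": 0.25, "rmse_f": 0.80})
    >>> check_loss_nan(101, {"rmse_e": float("nan")})  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    LossNaNError: NaN detected in loss at training step 101. ...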
    """
    nan_found = False
    for key, value in loss_dict.items():
        if LossNaNError._is_nan(value):
            nan_found = True
            log.error(f"NaN detected in {key} at step {step}: {value}")

    if nan_found:
        raise LossNaNError(step, loss_dict)


def check_single_loss_nan(step: int, loss_name: str, loss_value: Any) -> None:
    """Check if a single loss value contains NaN and raise an exception if found.

    Parameters
    ----------
    step : int
        Current training step
    loss_name : str
        Name/identifier of the loss
    loss_value : Any
        Loss value to check for NaN

    Raises
    ------
    LossNaNError
        If the loss value contains NaN
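
    Examples
    --------
    A minimal illustration; ``"rmse_f"`` is an arbitrary placeholder name:

    >>> check_single_loss_nan(5, "rmse_f", 0.3)
    >>> check_single_loss_nan(6, "rmse_f", float("nan"))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    LossNaNError: NaN detected in loss at training step 6. ...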
    """
    if LossNaNError._is_nan(loss_value):
        log.error(f"NaN detected in {loss_name} at step {step}: {loss_value}")
        raise LossNaNError(step, {loss_name: loss_value})
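

if __name__ == "__main__":
    # Illustrative smoke test, a sketch that is not part of the module's API;
    # the loss names and values below are placeholders.
    check_loss_nan(step=0, loss_dict={"rmse_e": 0.01, "rmse_f": 0.10})  # no NaN, passes
    try:
        check_single_loss_nan(step=1, loss_name="rmse_e", loss_value=float("nan"))
    except LossNaNError as err:
        print(f"Caught expected error: {err}")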