
Commit 0852b7c

Copilot and njzjz committed
fix(training): optimize NaN detection based on feedback - use lcurve CPU values and fixed loss keys
Co-authored-by: njzjz <[email protected]>
1 parent 5a22dfc · commit 0852b7c

File tree

4 files changed: +603 -528 lines changed

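This commit moves the NaN check out of the per-step hot path and onto the logging path. Previously, `check_total_loss_nan(_step_id + 1, loss.item())` ran inside every training step, and `loss.item()` forces a device-to-host copy of the loss tensor, so each step paid a GPU synchronization. The check now reuses the CPU loss values already computed for the lcurve output and keys on the fixed `"rmse_e"` entry. The helper itself is not part of this diff; below is a minimal sketch of what it plausibly does, assuming it simply aborts training once the logged loss is no longer a finite number (the name and call signature come from the hunks; the body and exception type are hypothetical):

```python
import math


def check_total_loss_nan(step_id: int, total_loss: float) -> None:
    """Abort training when the total loss has become NaN.

    Hypothetical sketch; the real helper lives elsewhere in deepmd-kit
    and may differ in message wording and exception type.
    """
    if math.isnan(total_loss):
        raise FloatingPointError(
            f"Total loss is NaN at step {step_id}; aborting to avoid "
            "writing a corrupted checkpoint."
        )
```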

deepmd/pd/train/training.py

Lines changed: 8 additions & 2 deletions
@@ -781,8 +781,6 @@ def step(_step_id, task_key="Default") -> None:
                     label=label_dict,
                     task_key=task_key,
                 )
-                # Check for NaN in total loss before backward pass to prevent corrupted training
-                check_total_loss_nan(_step_id + 1, loss.item())

                 with nvprof_context(enable_profiling, "Backward pass"):
                     loss.backward()
@@ -864,6 +862,9 @@ def log_loss_valid(_task_key="Default"):

             if not self.multi_task:
                 train_results = log_loss_train(loss, more_loss)
+                # Check for NaN in total loss using CPU values from lcurve computation
+                if self.rank == 0 and "rmse_e" in train_results:
+                    check_total_loss_nan(display_step_id, train_results["rmse_e"])
                 valid_results = log_loss_valid()
                 if self.rank == 0:
                     log.info(
@@ -905,6 +906,11 @@ def log_loss_valid(_task_key="Default"):
                         loss, more_loss, _task_key=_key
                     )
                     valid_results[_key] = log_loss_valid(_task_key=_key)
+                    # Check for NaN in total loss using CPU values from lcurve computation
+                    if self.rank == 0 and "rmse_e" in train_results[_key]:
+                        check_total_loss_nan(
+                            display_step_id, train_results[_key]["rmse_e"]
+                        )
                 if self.rank == 0:
                     log.info(
                         format_training_message_per_task(

deepmd/pt/train/training.py

Lines changed: 8 additions & 6 deletions
@@ -764,8 +764,6 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                 model_pred, loss, more_loss = self.wrapper(
                     **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
                 )
-                # Check for NaN in total loss before backward pass to prevent corrupted training
-                check_total_loss_nan(_step_id + 1, loss.item())
                 loss.backward()
                 if self.gradient_max_norm > 0.0:
                     torch.nn.utils.clip_grad_norm_(
@@ -817,8 +815,6 @@ def fake_model() -> dict:
                     int(input_dict["atype"].shape[-1]),
                     learning_rate=pref_lr,
                 )
-                # Check for NaN in total loss before continuing training
-                check_total_loss_nan(_step_id + 1, loss.item())
             elif isinstance(self.loss, DenoiseLoss):
                 KFOptWrapper = KFOptimizerWrapper(
                     self.wrapper,
@@ -845,8 +841,6 @@ def fake_model() -> dict:
                     input_dict["natoms"],
                     learning_rate=pref_lr,
                 )
-                # Check for NaN in total loss before continuing training
-                check_total_loss_nan(_step_id + 1, loss.item())
             else:
                 raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

@@ -958,6 +952,9 @@ def log_loss_valid(_task_key: str = "Default") -> dict:

             if not self.multi_task:
                 train_results = log_loss_train(loss, more_loss)
+                # Check for NaN in total loss using CPU values from lcurve computation
+                if self.rank == 0 and "rmse_e" in train_results:
+                    check_total_loss_nan(display_step_id, train_results["rmse_e"])
                 valid_results = log_loss_valid()
                 if self.rank == 0:
                     log.info(
@@ -1006,6 +1003,11 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                         loss, more_loss, _task_key=_key
                     )
                     valid_results[_key] = log_loss_valid(_task_key=_key)
+                    # Check for NaN in total loss using CPU values from lcurve computation
+                    if self.rank == 0 and "rmse_e" in train_results[_key]:
+                        check_total_loss_nan(
+                            display_step_id, train_results[_key]["rmse_e"]
+                        )
                 if self.rank == 0:
                     log.info(
                         format_training_message_per_task(
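The lines removed from `step` above are the hot-path cost this commit targets: calling `.item()` on a CUDA tensor blocks until every kernel feeding that tensor has finished, whereas the lcurve path already holds plain CPU floats. A minimal standalone demonstration of that implicit synchronization (not deepmd-kit code):

```python
import torch

# .item() copies a single scalar from device to host and blocks the
# Python thread until the GPU has finished computing the tensor, so
# calling it on the loss every step serializes the training loop.
if torch.cuda.is_available():
    loss = torch.randn((), device="cuda")
    value = loss.item()  # implicit device synchronization happens here
    print(f"loss on host: {value:.4f}")
else:
    print("CUDA not available; on CPU .item() involves no device sync.")
```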

deepmd/tf/train/trainer.py

Lines changed: 2 additions & 12 deletions
@@ -689,18 +689,8 @@ def valid_on_the_fly(
         current_lr = run_sess(self.sess, self.learning_rate)

         # Check for NaN in total loss before writing to file and saving checkpoint
-        # We check the main loss component that represents total training loss
-        if train_results:
-            # Look for the main loss key (typically the first loss component)
-            main_loss_key = next(iter(train_results.keys())) if train_results else None
-            if main_loss_key and main_loss_key in train_results:
-                check_total_loss_nan(cur_batch, train_results[main_loss_key])
-
-        if valid_results:
-            # Check validation loss as well for consistency
-            main_loss_key = next(iter(valid_results.keys())) if valid_results else None
-            if main_loss_key and main_loss_key in valid_results:
-                check_total_loss_nan(cur_batch, valid_results[main_loss_key])
+        # We check the main energy loss component that represents total training loss
+        check_total_loss_nan(cur_batch, train_results["rmse_e"])

         if print_header:
             self.print_header(fp, train_results, valid_results)
