Skip to content

Commit 7a2b41e

Browse files
Copilotnjzjz
andcommitted
fix(training): use 'rmse' key for total loss instead of 'rmse_e' for energy loss
Co-authored-by: njzjz <[email protected]>
1 parent 0852b7c commit 7a2b41e

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

deepmd/pd/train/training.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -863,8 +863,8 @@ def log_loss_valid(_task_key="Default"):
863863
if not self.multi_task:
864864
train_results = log_loss_train(loss, more_loss)
865865
# Check for NaN in total loss using CPU values from lcurve computation
866-
if self.rank == 0 and "rmse_e" in train_results:
867-
check_total_loss_nan(display_step_id, train_results["rmse_e"])
866+
if self.rank == 0 and "rmse" in train_results:
867+
check_total_loss_nan(display_step_id, train_results["rmse"])
868868
valid_results = log_loss_valid()
869869
if self.rank == 0:
870870
log.info(
@@ -907,9 +907,9 @@ def log_loss_valid(_task_key="Default"):
907907
)
908908
valid_results[_key] = log_loss_valid(_task_key=_key)
909909
# Check for NaN in total loss using CPU values from lcurve computation
910-
if self.rank == 0 and "rmse_e" in train_results[_key]:
910+
if self.rank == 0 and "rmse" in train_results[_key]:
911911
check_total_loss_nan(
912-
display_step_id, train_results[_key]["rmse_e"]
912+
display_step_id, train_results[_key]["rmse"]
913913
)
914914
if self.rank == 0:
915915
log.info(

deepmd/pt/train/training.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -953,8 +953,8 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
953953
if not self.multi_task:
954954
train_results = log_loss_train(loss, more_loss)
955955
# Check for NaN in total loss using CPU values from lcurve computation
956-
if self.rank == 0 and "rmse_e" in train_results:
957-
check_total_loss_nan(display_step_id, train_results["rmse_e"])
956+
if self.rank == 0 and "rmse" in train_results:
957+
check_total_loss_nan(display_step_id, train_results["rmse"])
958958
valid_results = log_loss_valid()
959959
if self.rank == 0:
960960
log.info(
@@ -1004,9 +1004,9 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
10041004
)
10051005
valid_results[_key] = log_loss_valid(_task_key=_key)
10061006
# Check for NaN in total loss using CPU values from lcurve computation
1007-
if self.rank == 0 and "rmse_e" in train_results[_key]:
1007+
if self.rank == 0 and "rmse" in train_results[_key]:
10081008
check_total_loss_nan(
1009-
display_step_id, train_results[_key]["rmse_e"]
1009+
display_step_id, train_results[_key]["rmse"]
10101010
)
10111011
if self.rank == 0:
10121012
log.info(

deepmd/tf/train/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -689,8 +689,8 @@ def valid_on_the_fly(
689689
current_lr = run_sess(self.sess, self.learning_rate)
690690

691691
# Check for NaN in total loss before writing to file and saving checkpoint
692-
# We check the main energy loss component that represents total training loss
693-
check_total_loss_nan(cur_batch, train_results["rmse_e"])
692+
# We check the main total loss component that represents training loss
693+
check_total_loss_nan(cur_batch, train_results["rmse"])
694694

695695
if print_header:
696696
self.print_header(fp, train_results, valid_results)

0 commit comments

Comments
 (0)