@@ -764,8 +764,6 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                 model_pred, loss, more_loss = self.wrapper(
                     **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
                 )
-                # Check for NaN in total loss before backward pass to prevent corrupted training
-                check_total_loss_nan(_step_id + 1, loss.item())
                 loss.backward()
                 if self.gradient_max_norm > 0.0:
                     torch.nn.utils.clip_grad_norm_(
@@ -817,8 +815,6 @@ def fake_model() -> dict:
                         int(input_dict["atype"].shape[-1]),
                         learning_rate=pref_lr,
                     )
-                    # Check for NaN in total loss before continuing training
-                    check_total_loss_nan(_step_id + 1, loss.item())
                 elif isinstance(self.loss, DenoiseLoss):
                     KFOptWrapper = KFOptimizerWrapper(
                         self.wrapper,
@@ -845,8 +841,6 @@ def fake_model() -> dict:
                         input_dict["natoms"],
                         learning_rate=pref_lr,
                     )
-                    # Check for NaN in total loss before continuing training
-                    check_total_loss_nan(_step_id + 1, loss.item())
             else:
                 raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
 
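The three hunks above drop the per-step NaN check. Each removed call evaluated loss.item(), which copies the scalar to the host and blocks until the device has finished all pending work, so the check added a full synchronization to every training step. The two hunks below restore the check on the display path, where the lcurve computation has already materialized the loss metrics as CPU floats. A self-contained illustration of the synchronization cost being removed (assumes a CUDA build of PyTorch; illustrative only, not code from this commit):

    import torch

    if torch.cuda.is_available():
        loss = torch.rand((), device="cuda")  # stand-in for a training loss
        # .item() triggers a device-to-host copy and blocks until the value
        # is ready, stalling the stream the training kernels run on.
        value = loss.item()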
@@ -958,6 +952,9 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
 
             if not self.multi_task:
                 train_results = log_loss_train(loss, more_loss)
+                # Check for NaN in total loss using CPU values from lcurve computation
+                if self.rank == 0 and "rmse_e" in train_results:
+                    check_total_loss_nan(display_step_id, train_results["rmse_e"])
                 valid_results = log_loss_valid()
                 if self.rank == 0:
                     log.info(
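In the single-task hunk above, the membership test guards against loss types that do not populate an energy RMSE: train_results comes from log_loss_train(), and presumably only losses with an energy term emit an "rmse_e" entry, so indexing unconditionally could raise a KeyError. A hypothetical, runnable rendering of the same guard, with invented values:

    import math

    display_step_id = 100                              # hypothetical step index
    train_results = {"rmse_e": 0.12, "rmse_f": 0.034}  # invented example values
    if "rmse_e" in train_results and math.isnan(train_results["rmse_e"]):
        raise FloatingPointError(f"NaN total loss at step {display_step_id}")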
@@ -1006,6 +1003,11 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                             loss, more_loss, _task_key=_key
                         )
                     valid_results[_key] = log_loss_valid(_task_key=_key)
+                    # Check for NaN in total loss using CPU values from lcurve computation
+                    if self.rank == 0 and "rmse_e" in train_results[_key]:
+                        check_total_loss_nan(
+                            display_step_id, train_results[_key]["rmse_e"]
+                        )
                     if self.rank == 0:
                         log.info(
                             format_training_message_per_task(
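The helper itself does not appear in these hunks. A minimal sketch of what it plausibly does, assuming only the (step_id, total_loss) call shape visible above; the real check_total_loss_nan in deepmd-kit may differ in wording and behavior:

    import math


    def check_total_loss_nan(step_id: int, total_loss: float) -> None:
        """Hypothetical sketch: abort training once the total loss is NaN.

        Failing fast keeps a NaN from propagating through the optimizer
        state and poisoning every checkpoint written afterwards.
        """
        if math.isnan(total_loss):
            raise FloatingPointError(
                f"Total loss is NaN at step {step_id}; aborting training."
            )

One design trade-off is worth stating: because the new call sites feed the helper values that were already converted to CPU floats for the learning curve, the check costs nothing extra on the hot path, at the price of detecting a NaN only at display steps rather than on the exact step it first occurs.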