File tree Expand file tree Collapse file tree 3 files changed +10
-10
lines changed Expand file tree Collapse file tree 3 files changed +10
-10
lines changed Original file line number Diff line number Diff line change @@ -863,8 +863,8 @@ def log_loss_valid(_task_key="Default"):
863
863
if not self .multi_task :
864
864
train_results = log_loss_train (loss , more_loss )
865
865
# Check for NaN in total loss using CPU values from lcurve computation
866
- if self .rank == 0 and "rmse_e " in train_results :
867
- check_total_loss_nan (display_step_id , train_results ["rmse_e " ])
866
+ if self .rank == 0 and "rmse " in train_results :
867
+ check_total_loss_nan (display_step_id , train_results ["rmse " ])
868
868
valid_results = log_loss_valid ()
869
869
if self .rank == 0 :
870
870
log .info (
@@ -907,9 +907,9 @@ def log_loss_valid(_task_key="Default"):
907
907
)
908
908
valid_results [_key ] = log_loss_valid (_task_key = _key )
909
909
# Check for NaN in total loss using CPU values from lcurve computation
910
- if self .rank == 0 and "rmse_e " in train_results [_key ]:
910
+ if self .rank == 0 and "rmse " in train_results [_key ]:
911
911
check_total_loss_nan (
912
- display_step_id , train_results [_key ]["rmse_e " ]
912
+ display_step_id , train_results [_key ]["rmse " ]
913
913
)
914
914
if self .rank == 0 :
915
915
log .info (
Original file line number Diff line number Diff line change @@ -953,8 +953,8 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
953
953
if not self .multi_task :
954
954
train_results = log_loss_train (loss , more_loss )
955
955
# Check for NaN in total loss using CPU values from lcurve computation
956
- if self .rank == 0 and "rmse_e " in train_results :
957
- check_total_loss_nan (display_step_id , train_results ["rmse_e " ])
956
+ if self .rank == 0 and "rmse " in train_results :
957
+ check_total_loss_nan (display_step_id , train_results ["rmse " ])
958
958
valid_results = log_loss_valid ()
959
959
if self .rank == 0 :
960
960
log .info (
@@ -1004,9 +1004,9 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
1004
1004
)
1005
1005
valid_results [_key ] = log_loss_valid (_task_key = _key )
1006
1006
# Check for NaN in total loss using CPU values from lcurve computation
1007
- if self .rank == 0 and "rmse_e " in train_results [_key ]:
1007
+ if self .rank == 0 and "rmse " in train_results [_key ]:
1008
1008
check_total_loss_nan (
1009
- display_step_id , train_results [_key ]["rmse_e " ]
1009
+ display_step_id , train_results [_key ]["rmse " ]
1010
1010
)
1011
1011
if self .rank == 0 :
1012
1012
log .info (
Original file line number Diff line number Diff line change @@ -689,8 +689,8 @@ def valid_on_the_fly(
689
689
current_lr = run_sess (self .sess , self .learning_rate )
690
690
691
691
# Check for NaN in total loss before writing to file and saving checkpoint
692
- # We check the main energy loss component that represents total training loss
693
- check_total_loss_nan (cur_batch , train_results ["rmse_e " ])
692
+ # We check the main total loss component that represents training loss
693
+ check_total_loss_nan (cur_batch , train_results ["rmse " ])
694
694
695
695
if print_header :
696
696
self .print_header (fp , train_results , valid_results )
You can’t perform that action at this time.
0 commit comments