diff --git a/megatron/training/training.py b/megatron/training/training.py index f564d90d98..1a92add3c6 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1999,6 +1999,7 @@ def post_training_step_callbacks( # Straggler detector. if iteration % args.log_interval == 0 and args.log_straggler: + # Use FLOPs accumulated since last log event and then reset the counter stimer.report(num_floating_point_operations_since_last_log_event, args.log_interval) num_floating_point_operations_since_last_log_event = 0.0 @@ -2038,6 +2039,9 @@ def post_training_step_callbacks( if args.manual_gc_interval != 0 and iteration % args.manual_gc_interval == 0: gc.collect() + # Return updated FLOPs accumulator so caller can persist the reset + return num_floating_point_operations_since_last_log_event + def checkpoint_and_decide_exit( model, @@ -2516,8 +2520,9 @@ def get_e2e_base_metrics(): energy_monitor.resume() # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC). - # Some of these only happen at specific iterations. - post_training_step_callbacks( + # Some of these only happen at specific iterations. Capture updated FLOPs accumulator + # (it is reset inside the callback after logging). + num_floating_point_operations_since_last_log_event = post_training_step_callbacks( model, optimizer, opt_param_scheduler,