From 0354be3cd596a52eb47d5feda6aa1e46637b0503 Mon Sep 17 00:00:00 2001 From: cms42 Date: Tue, 19 Aug 2025 18:47:58 +0800 Subject: [PATCH] fix(training): Reset straggler detector FLOPs accumulator after each log interval Co-authored-by: Li Ruixiao --- megatron/training/training.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/megatron/training/training.py b/megatron/training/training.py index f564d90d98..1a92add3c6 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1999,6 +1999,7 @@ def post_training_step_callbacks( # Straggler detector. if iteration % args.log_interval == 0 and args.log_straggler: + # Use FLOPs accumulated since last log event and then reset the counter stimer.report(num_floating_point_operations_since_last_log_event, args.log_interval) num_floating_point_operations_since_last_log_event = 0.0 @@ -2038,6 +2039,9 @@ def post_training_step_callbacks( if args.manual_gc_interval != 0 and iteration % args.manual_gc_interval == 0: gc.collect() + # Return updated FLOPs accumulator so caller can persist the reset + return num_floating_point_operations_since_last_log_event + def checkpoint_and_decide_exit( model, @@ -2516,8 +2520,9 @@ def get_e2e_base_metrics(): energy_monitor.resume() # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC). - # Some of these only happen at specific iterations. - post_training_step_callbacks( + # Some of these only happen at specific iterations. Capture updated FLOPs accumulator + # (it is reset inside the callback after logging). + num_floating_point_operations_since_last_log_event = post_training_step_callbacks( model, optimizer, opt_param_scheduler,