Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions megatron/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,6 +1999,7 @@ def post_training_step_callbacks(

# Straggler detector.
if iteration % args.log_interval == 0 and args.log_straggler:
# Report the FLOPs accumulated since the last log event, then reset the counter.
stimer.report(num_floating_point_operations_since_last_log_event, args.log_interval)
num_floating_point_operations_since_last_log_event = 0.0

Expand Down Expand Up @@ -2038,6 +2039,9 @@ def post_training_step_callbacks(
if args.manual_gc_interval != 0 and iteration % args.manual_gc_interval == 0:
gc.collect()

# Return the updated FLOPs accumulator so the caller can persist the reset.
return num_floating_point_operations_since_last_log_event


def checkpoint_and_decide_exit(
model,
Expand Down Expand Up @@ -2516,8 +2520,9 @@ def get_e2e_base_metrics():
energy_monitor.resume()

# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations.
post_training_step_callbacks(
# Some of these only happen at specific iterations. Capture the returned FLOPs accumulator
# (the callback resets it to zero after reporting it to the straggler detector).
num_floating_point_operations_since_last_log_event = post_training_step_callbacks(
model,
optimizer,
opt_param_scheduler,
Expand Down