
Commit 2db72cf

feat: log disk space usage info, warn if close to exhaustion
Signed-off-by: Ihar Hrachyshka <[email protected]>
1 parent e94f8ab commit 2db72cf

2 files changed (+91, -15 lines)

src/instructlab/training/model.py

Lines changed: 9 additions & 0 deletions
@@ -55,6 +55,7 @@ def __init__(
         self.noise_alpha = noise_alpha
         self.tokenizer = tokenizer
         self.distributed_framework = distributed_framework
+        self._last_checkpoint_size: int | None = None
         bnb_config = None
         if lora_config and lora_config.r > 0 and lora_quant_bits == 4:
             # Third Party
@@ -76,6 +77,14 @@ def __init__(
         if flash_enabled:
             self.base_model_args["attn_implementation"] = "flash_attention_2"

+    @property
+    def last_checkpoint_size(self) -> int | None:
+        return self._last_checkpoint_size
+
+    @last_checkpoint_size.setter
+    def last_checkpoint_size(self, value: int):
+        self._last_checkpoint_size = value
+
     def _post_model_init(self):
         """Common initialization steps that should happen after model initialization."""
         self.reconcile_tokenizer()
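
Note: the attribute added above is a simple cache. It stays None until a checkpoint directory has actually been measured, which is what lets the disk-space check in utils.py silently skip the very first save. A minimal sketch of that contract (the SizedModel stub below is hypothetical; only the property names follow the diff):

class SizedModel:
    """Hypothetical stand-in illustrating the last_checkpoint_size contract."""

    def __init__(self) -> None:
        self._last_checkpoint_size: int | None = None

    @property
    def last_checkpoint_size(self) -> int | None:
        return self._last_checkpoint_size

    @last_checkpoint_size.setter
    def last_checkpoint_size(self, value: int) -> None:
        self._last_checkpoint_size = value


m = SizedModel()
assert m.last_checkpoint_size is None  # nothing measured yet -> disk check is a no-op
m.last_checkpoint_size = 4 * 1024**3   # recorded once a checkpoint has been written
assert m.last_checkpoint_size == 4 * 1024**3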

src/instructlab/training/utils.py

Lines changed: 82 additions & 15 deletions
@@ -624,15 +624,29 @@ def get_caller(num_frames=1):
     return f"In {file_name}, line {line_number}"


-def log_rank_0(msg, include_caller=False, rank=None, to_print=False):
+def log_rank_0(
+    msg, include_caller=False, rank=None, to_print=False, level=logging.INFO
+) -> None:
     if rank is None:
         rank = get_rank() if is_initialized() else 0
-    if rank <= 0:
-        if include_caller:
-            msg = f"{get_caller(num_frames=2)}: {msg}"
-        if to_print:
-            print(msg)
-        else:
+    if rank > 0:
+        return
+
+    if include_caller:
+        msg = f"{get_caller(num_frames=2)}: {msg}"
+
+    if to_print:
+        print(msg)
+        return
+
+    match level:
+        case logging.WARNING:
+            logger.warning(msg)
+        case logging.ERROR:
+            logger.error(msg)
+        case logging.DEBUG:
+            logger.debug(msg)
+        case _:
             logger.info(msg)

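Note: with the new level parameter, rank-0 callers can route a message to the warning, error, or debug channel; anything else falls through to logger.info, and to_print=True still bypasses the logger entirely. A short usage sketch under those assumptions (the import path mirrors the file being edited; logger configuration is elided):

import logging

# Assumes the package layout implied by src/instructlab/training/utils.py.
from instructlab.training.utils import log_rank_0

log_rank_0("Saving model in huggingface format", to_print=True)      # printed, not logged
log_rank_0("Disk space info: free=1024.00 MB")                       # logger.info (default)
log_rank_0("Free disk space is running low", level=logging.WARNING)  # logger.warning
log_rank_0("Could not check disk space", level=logging.ERROR)        # logger.error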

@@ -673,6 +687,13 @@ def skip_precheck_loops():
     accelerator.get_state_dict = old_get_state


+def _get_checkpoint_dir(args, samples_seen) -> Path:
+    subdir = (
+        "last_epoch" if args.keep_last_checkpoint_only else f"samples_{samples_seen}"
+    )
+    return Path(args.output_dir) / "hf_format" / subdir
+
+
 def save_hf_format_accelerate(
     args,
     model,
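
Note: factoring the path logic into _get_checkpoint_dir lets save_checkpoint recompute the same directory later (see below) instead of duplicating the naming rule. A self-contained illustration of the two naming modes (the SimpleNamespace args object is a stand-in for the real training arguments):

from pathlib import Path
from types import SimpleNamespace


def _get_checkpoint_dir(args, samples_seen) -> Path:
    # Same rule as the helper added above.
    subdir = (
        "last_epoch" if args.keep_last_checkpoint_only else f"samples_{samples_seen}"
    )
    return Path(args.output_dir) / "hf_format" / subdir


args = SimpleNamespace(output_dir="/tmp/run", keep_last_checkpoint_only=False)
print(_get_checkpoint_dir(args, samples_seen=1000))  # /tmp/run/hf_format/samples_1000

args.keep_last_checkpoint_only = True
print(_get_checkpoint_dir(args, samples_seen=1000))  # /tmp/run/hf_format/last_epoch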
@@ -681,13 +702,11 @@
     samples_seen,
     is_lora=False,
 ):
-    # Build the subdirectory name
-    subdir = (
-        "last_epoch" if args.keep_last_checkpoint_only else f"samples_{samples_seen}"
-    )
+    # Build the final output directory path
+    final_output_dir = _get_checkpoint_dir(args, samples_seen)

     log_rank_0(
-        f"\033[93mSaving model in huggingface format at: {subdir}\033[0m",
+        f"\033[93mSaving model in huggingface format at: {final_output_dir}\033[0m",
         to_print=True,
     )
     start = time.time()
@@ -697,9 +716,6 @@
     else:
         convert_dolomite = True

-    # Build the final output directory path
-    final_output_dir = Path(args.output_dir) / "hf_format" / subdir
-
     if args.use_dolomite and convert_dolomite:
         tmpdir = TemporaryDirectory("w") # pylint: disable=consider-using-with
         output_dir = Path(tmpdir.name)
@@ -797,6 +813,48 @@ def set_random_seed(seed):
     torch.cuda.manual_seed_all(seed)


+def _get_checkpoint_dir_size(checkpoint_dir) -> int:
+    total = 0
+    for dirpath, _, filenames in os.walk(checkpoint_dir):
+        for f in filenames:
+            fp = os.path.join(dirpath, f)
+            if os.path.isfile(fp):
+                total += os.path.getsize(fp)
+    return total
+
+
+def check_disk_space_for_next_checkpoint(
+    model: Model, output_dir: Path, warn_steps_ahead: int = 3
+) -> None:
+    checkpoint_size = model.last_checkpoint_size
+    if checkpoint_size is None:
+        # No previous checkpoint size to estimate, do nothing.
+        return
+
+    def _mb_size(num_bytes):
+        return f"{num_bytes / 1024 / 1024:.2f} MB"
+
+    try:
+        stat = shutil.disk_usage(output_dir)
+        free_bytes = stat.free
+        needed_bytes = checkpoint_size * warn_steps_ahead
+
+        log_rank_0(
+            f"Disk space info: free={_mb_size(free_bytes)}, last_checkpoint_size={_mb_size(checkpoint_size)} (output_dir={output_dir})"
+        )
+        if free_bytes < needed_bytes:
+            log_rank_0(
+                f"Estimated free disk space ({_mb_size(free_bytes)}) is less than the estimated size of the next {warn_steps_ahead} checkpoints ({_mb_size(needed_bytes)}). "
+                "The next checkpoint(s) may fail due to insufficient disk space.",
+                level=logging.WARNING,
+            )
+    except Exception as e:
+        log_rank_0(
+            f"Could not check disk space after checkpoint: {e}",
+            level=logging.ERROR,
+        )
+
+
 def save_checkpoint(
     args,
     accelerator: Accelerator,
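
Note: the warning threshold is simply last_checkpoint_size * warn_steps_ahead compared against shutil.disk_usage(output_dir).free. With a 4 GiB previous checkpoint and the default warn_steps_ahead=3, the warning fires once free space drops below 12 GiB. A standalone sketch of that arithmetic (sizes and path are illustrative only):

import shutil

warn_steps_ahead = 3
last_checkpoint_size = 4 * 1024**3                      # 4 GiB, measured from the previous save
needed_bytes = last_checkpoint_size * warn_steps_ahead  # ~12 GiB of headroom expected

free_bytes = shutil.disk_usage(".").free                # any existing path works for the demo
if free_bytes < needed_bytes:
    print(
        f"warning: {free_bytes / 1024 / 1024:.2f} MB free, "
        f"next {warn_steps_ahead} checkpoints need ~{needed_bytes / 1024 / 1024:.2f} MB"
    )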
@@ -827,6 +885,15 @@ def save_checkpoint(
             samples_seen=samples_seen,
         )

+    # Track checkpoint size and warn if disk space is low
+    output_dir = Path(args.output_dir)
+    check_disk_space_for_next_checkpoint(model, output_dir, warn_steps_ahead=3)
+
+    if hf_format:
+        checkpoint_dir = _get_checkpoint_dir(args, samples_seen)
+        if checkpoint_dir.exists():
+            model.last_checkpoint_size = _get_checkpoint_dir_size(checkpoint_dir)
+

 def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int):
     """
