
[train] Add fault tolerance variant to the training data ingest benchmark #50399

Merged · 33 commits into ray-project:master · Feb 13, 2025

Conversation

@justinvyu (Contributor) commented Feb 10, 2025

Summary

Adds a .skip_training.fault_tolerance variant to the image classification training ingest release test, which kills a node every N seconds and tests worker recovery.

Also adds the ability to stress-test concurrent multi-dataset execution with training ingest (training and validation datasets) by performing validation every N steps during the training epoch. This is not enabled yet because it is not performant enough and would cause the test to run for too long; it will be addressed in a follow-up PR.

Also updates the training benchmark script to load training state properly and perform mid-epoch resumption by skipping batches up to the batch that corresponds to the latest checkpoint.

Adds the following metrics (recorded via per-key metric objects; a sketch of such an object follows the list):

  • checkpoint/download: Time spent downloading the checkpoint from storage to local.
  • checkpoint/load: Time spent loading the checkpoint from local.
  • train/iter_skip_batch: Time spent skipping batches upon restoration to do "mid-epoch resumption."
  • checkpoint/restoration_time: Extra time spent on restoration; the sum of the three metrics above.
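The code snippets quoted in the review below record these metrics through per-key objects that expose .add(), .get(), and a .timer() context manager. A minimal sketch of what such an object might look like, purely for illustration (this is not the benchmark's actual implementation):

import time
from contextlib import contextmanager


class MetricAccumulator:
    """Illustrative stand-in for the benchmark's per-key metric objects."""

    def __init__(self) -> None:
        self._total: float = 0.0

    def add(self, value: float) -> None:
        self._total += value

    def get(self) -> float:
        return self._total

    @contextmanager
    def timer(self):
        # Accumulate wall-clock time spent inside the `with` block.
        start = time.perf_counter()
        try:
            yield
        finally:
            self.add(time.perf_counter() - start)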

Here's what the fault tolerance test does:

  • Starts a chaos killer which kills a node every ~480 seconds, killing up to 2 nodes across the entire job.
  • Runs training with max_failures=4 (2 more failures than needed, as a buffer); a configuration sketch follows below.
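For reference, here is a minimal sketch of how the max_failures=4 setting maps onto Ray Train's public API; the trainer arguments and sizing are illustrative, not the benchmark's exact wiring:

import ray.train
from ray.train.torch import TorchTrainer


def train_fn(config):
    # Placeholder for the benchmark's per-worker training loop.
    pass


trainer = TorchTrainer(
    train_fn,
    # Illustrative sizing; the release test defines its own cluster shape.
    scaling_config=ray.train.ScalingConfig(num_workers=16, use_gpu=True),
    run_config=ray.train.RunConfig(
        # Tolerate up to 4 worker-group failures: ~2 are expected from the
        # chaos killer, and the remaining 2 act as a buffer.
        failure_config=ray.train.FailureConfig(max_failures=4),
    ),
)
result = trainer.fit()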

Comment on lines 60 to 73
if self.benchmark_config.validate_every_n_steps > 0:
    # TODO: This is just hard-coded for now. Maybe move this to be a configuration.
    # Maybe move this to the RayDataLoaderFactory.
    cpus_to_exclude = 16
    train_ds.context.execution_options.exclude_resources = (
        train_ds.context.execution_options.exclude_resources.add(
            ray.data.ExecutionResources(cpu=cpus_to_exclude)
        )
    )
    logger.info(
        f"[Dataloader] Reserving {cpus_to_exclude} CPUs for validation "
        "that happens concurrently with training every "
        f"{self.benchmark_config.validate_every_n_steps} steps. "
    )
@justinvyu (Contributor, Author):

@raulchen Is this a reasonable way to handle multi-dataset?
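For context, the concurrent validation described in the summary amounts to something like the loop below; train_step and validate are illustrative placeholders rather than the benchmark's exact APIs:

def train_epoch(train_dataloader, val_dataloader, benchmark_config, train_step, validate):
    # Interleave a validation pass every N training steps so that the training
    # and validation datasets execute concurrently, with the validation dataset
    # running on the CPUs excluded from the training dataset above.
    for step, batch in enumerate(train_dataloader):
        train_step(batch)

        if (
            benchmark_config.validate_every_n_steps > 0
            and step > 0
            and step % benchmark_config.validate_every_n_steps == 0
        ):
            validate(val_dataloader)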

Comment on lines +280 to +281
if restoration_time > 0:
    metrics["checkpoint/restoration_time"] = restoration_time
@justinvyu (Contributor, Author):

Note: We should track "restoration time" in Ray Train by default. The time measured here does not include process group re-init, actor startup, etc. It would be good to sum those with this metric to get the full restoration/startup cost.

Comment on lines 90 to 94
# Skip through batches if we restored to a middle of the epoch.
# TODO: Compare this baseline to the data checkpointing approach once we have it.
for _ in range(self._train_batch_idx):
    with self._metrics["train/iter_skip_batch"].timer():
        self.get_next_batch(train_dataloader)
@justinvyu (Contributor, Author):

Keeping an accurate total for the time spent skipping batches is a little tricky, because this step could take a long time (e.g., if training was previously killed at the second-to-last batch of the epoch). A node could then get killed during the skipping itself, and that time would be lost from the metrics, which are only snapshotted at every checkpoint.

We may want to checkpoint the metrics separately from the model if we want to track this better.
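A rough sketch of what checkpointing the metrics separately might look like; the file layout and helper name are hypothetical:

import json
import os
import time


def save_metrics_snapshot(metrics: dict, snapshot_dir: str) -> None:
    # Persist the current metric totals on their own, independent of (and
    # potentially more often than) the model checkpoint, so that e.g. time
    # spent skipping batches is not lost if a node dies before the next
    # model checkpoint.
    os.makedirs(snapshot_dir, exist_ok=True)
    snapshot = {key: metric.get() for key, metric in metrics.items()}
    path = os.path.join(snapshot_dir, f"metrics_{int(time.time())}.json")
    with open(path, "w") as f:
        json.dump(snapshot, f)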

@matthewdeng (Contributor) left a comment:

Very clean!

Comment on lines +52 to +61
download_start = time.perf_counter()
checkpoint.to_directory(temp_checkpoint_dir)
download_time = time.perf_counter() - download_start

load_start = time.perf_counter()
self.load_checkpoint(temp_checkpoint_dir)
load_time = time.perf_counter() - load_start

self._metrics["checkpoint/download"].add(download_time)
self._metrics["checkpoint/load"].add(load_time)
@matthewdeng (Contributor):

nit: use the same context manager pattern as is done later for the batch iteration?

e.g.

with self._metrics["checkpoint/download"].timer():
    checkpoint.to_directory(temp_checkpoint_dir)


with self._metrics["checkpoint/load"].timer():
    self.load_checkpoint(temp_checkpoint_dir)

@justinvyu (Contributor, Author):

load_checkpoint loads the snapshot of the metrics, which would overwrite these values. So I need to save the measured times separately and then add them to the metrics afterwards.

Comment on lines 184 to +186
    with self._metrics["validation/step"].timer():
-       with torch.no_grad():
-           out = self.model(input_batch)
-           loss = self.loss_fn(out, labels)
-           total_loss += loss
-       num_rows += len(labels)
-       self._metrics["validation/rows_processed"].add(len(labels))
+       if not self.benchmark_config.skip_validation_step:
+           total_loss += self.validate_step(input_batch, labels)
@matthewdeng (Contributor):

Invert the ordering of the timer and the condition? Or do you intentionally want to include this as well, as you mentioned in the other comment?

            if not self.benchmark_config.skip_validation_step:
                with self._metrics["validation/step"].timer():
                    total_loss += self.validate_step(input_batch, labels)

@justinvyu (Contributor, Author):

I think it's easier to parse the output if this metric always exists and is just 0 rather than existing conditionally.

# which includes downloading the checkpoint, loading the checkpoint,
# and skipping through batches that were already processed.
restoration_time = (
    self._metrics["checkpoint/download"].get()
@matthewdeng (Contributor):

nit: Use constants for all keys to make it safe against typos.

@justinvyu (Contributor, Author):

I'll do this in a followup.
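For reference, the follow-up could define the keys roughly like this (constant names are illustrative):

# Module-level constants for the metric keys, so producers and consumers of
# the metrics cannot drift apart via typos.
CHECKPOINT_DOWNLOAD_KEY = "checkpoint/download"
CHECKPOINT_LOAD_KEY = "checkpoint/load"
TRAIN_ITER_SKIP_BATCH_KEY = "train/iter_skip_batch"
CHECKPOINT_RESTORATION_TIME_KEY = "checkpoint/restoration_time"

The restoration_time computation above would then index self._metrics with these constants instead of raw string literals.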

@justinvyu justinvyu enabled auto-merge (squash) February 13, 2025 05:58
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Feb 13, 2025
@justinvyu justinvyu merged commit 68c0ead into ray-project:master Feb 13, 2025
7 checks passed