[train] Add fault tolerance variant to the training data ingest benchmark #50399
Conversation
if self.benchmark_config.validate_every_n_steps > 0:
    # TODO: This is just hard-coded for now. Maybe move this to be a configuration.
    # Maybe move this to the RayDataLoaderFactory.
    cpus_to_exclude = 16
    train_ds.context.execution_options.exclude_resources = (
        train_ds.context.execution_options.exclude_resources.add(
            ray.data.ExecutionResources(cpu=cpus_to_exclude)
        )
    )
    logger.info(
        f"[Dataloader] Reserving {cpus_to_exclude} CPUs for validation "
        "that happens concurrently with training every "
        f"{self.benchmark_config.validate_every_n_steps} steps."
    )
@raulchen Is this a reasonable way to handle multi-dataset?
if restoration_time > 0:
    metrics["checkpoint/restoration_time"] = restoration_time
Note: We should track "restoration time" in Ray Train by default. The time measured here does not include process group re-init, actor startup, etc. It would be good to sum those two together (the restoration time measured here plus the worker startup time) to get the total restoration/startup cost.
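A rough sketch of that suggestion, assuming a hypothetical worker_startup_time value that Ray Train would have to report itself (no such metric exists in this PR; get_worker_startup_time is a placeholder):

    # Restoration cost measured at the application level by this benchmark:
    # checkpoint download + checkpoint load + skipping already-processed batches.
    app_restoration_time = metrics["checkpoint/restoration_time"]

    # Hypothetical number that Ray Train itself would need to report:
    # process group re-init, actor startup, etc. (not tracked today).
    worker_startup_time = get_worker_startup_time()  # placeholder helper

    total_restoration_cost = app_restoration_time + worker_startup_time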
# Skip through batches if we restored to the middle of an epoch.
# TODO: Compare this baseline to the data checkpointing approach once we have it.
for _ in range(self._train_batch_idx):
    with self._metrics["train/iter_skip_batch"].timer():
        self.get_next_batch(train_dataloader)
Keeping an accurate total for the time spent "skipping" batches is a little tricky, because this step could take a long time (e.g., if training was previously killed at the second-to-last batch). A node could then get killed during this window, and the time spent skipping batches would be lost from the metrics, which are only snapshotted at every checkpoint.
May want to checkpoint the metrics separately from the model if we want to track this better.
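One way to do that, sketched below with hypothetical names (_save_metrics_snapshot, a to_dict() method on the metric objects, and metrics_snapshot.json are not part of this PR): persist the metrics as their own small file, independently of the model checkpoint.

    import json
    import os

    # Hypothetical method on the benchmark's training-loop class.
    def _save_metrics_snapshot(self, directory: str) -> None:
        # Persist only the metrics (not the model) so that time accumulated in
        # the skip-batch loop survives a node failure that happens before the
        # next full checkpoint.
        path = os.path.join(directory, "metrics_snapshot.json")
        with open(path, "w") as f:
            json.dump(
                {key: metric.to_dict() for key, metric in self._metrics.items()}, f
            )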
Very clean!
download_start = time.perf_counter()
checkpoint.to_directory(temp_checkpoint_dir)
download_time = time.perf_counter() - download_start

load_start = time.perf_counter()
self.load_checkpoint(temp_checkpoint_dir)
load_time = time.perf_counter() - load_start

self._metrics["checkpoint/download"].add(download_time)
self._metrics["checkpoint/load"].add(load_time)
nit: use same context manager pattern as done later with the batch iteration?
e.g.
with self._metrics["checkpoint/download"].timer():
    checkpoint.to_directory(temp_checkpoint_dir)
with self._metrics["checkpoint/load"].timer():
    self.load_checkpoint(temp_checkpoint_dir)
load_checkpoint will load in the snapshot of the metrics, which would overwrite these values. So I need to save the times separately and then add them to the metrics afterward.
with self._metrics["validation/step"].timer():
    with torch.no_grad():
        out = self.model(input_batch)
        loss = self.loss_fn(out, labels)
    total_loss += loss
    num_rows += len(labels)
    self._metrics["validation/rows_processed"].add(len(labels))
    if not self.benchmark_config.skip_validation_step:
        total_loss += self.validate_step(input_batch, labels)
Inverse ordering of timer and condition? Or do you intentionally want to include this as well, like you mentioned in the other comment?
if not self.benchmark_config.skip_validation_step:
    with self._metrics["validation/step"].timer():
        total_loss += self.validate_step(input_batch, labels)
I think it's easier to parse the output if this metric always exists and is just 0 rather than existing conditionally.
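In other words, the ordering keeps the timer on the outside so the key is always written, and it simply stays at ~0 when the step is skipped. A minimal sketch reusing the names from the snippets above:

    # Timer outside the condition: "validation/step" always exists in the
    # output, even when skip_validation_step is set.
    with self._metrics["validation/step"].timer():
        if not self.benchmark_config.skip_validation_step:
            total_loss += self.validate_step(input_batch, labels)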
# which includes downloading the checkpoint, loading the checkpoint,
# and skipping through batches that were already processed.
restoration_time = (
    self._metrics["checkpoint/download"].get()
nit: Use constants for all keys to make it safe against typos.
I'll do this in a follow-up.
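A minimal sketch of that follow-up, with hypothetical constant names (the key strings and the sum over the three metrics come from this PR):

    # Module-level constants for the metric keys, so a typo becomes a NameError
    # instead of silently creating a new metric.
    CHECKPOINT_DOWNLOAD_KEY = "checkpoint/download"
    CHECKPOINT_LOAD_KEY = "checkpoint/load"
    ITER_SKIP_BATCH_KEY = "train/iter_skip_batch"
    RESTORATION_TIME_KEY = "checkpoint/restoration_time"

    restoration_time = (
        self._metrics[CHECKPOINT_DOWNLOAD_KEY].get()
        + self._metrics[CHECKPOINT_LOAD_KEY].get()
        + self._metrics[ITER_SKIP_BATCH_KEY].get()
    )
    if restoration_time > 0:
        metrics[RESTORATION_TIME_KEY] = restoration_time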
Summary

Adds a .skip_training.fault_tolerance variant to the image classification training ingest release test, which kills a node every N seconds and tests worker recovery.

Also adds the ability to stress-test concurrent multi-dataset execution with training ingest (training and validation datasets) by performing validation during the training epoch every N steps. This is not enabled yet because it is not performant enough and would make the test run for too long; it will be addressed in a follow-up PR.

Also updates the training benchmark script to load training state properly and do mid-epoch resumption by skipping batches up to the batch that corresponds to the latest checkpoint.

Adds the following metrics:

checkpoint/download: Time spent downloading the checkpoint from storage to local disk.
checkpoint/load: Time spent loading the checkpoint from local disk.
train/iter_skip_batch: Time spent skipping batches upon restoration to do "mid-epoch resumption."
checkpoint/restoration_time: Extra time spent on restoration, which is the sum of the three metrics above.

Here's what the fault tolerance test does:

Runs with max_failures=4 (2 more failures than needed, as a buffer); a sketch of how this maps onto Ray Train's failure configuration follows below.
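For context, a minimal sketch of how the max_failures=4 setting maps onto Ray Train's failure configuration; the trainer setup and training loop below are placeholders, not the benchmark's actual code:

    from ray.train import FailureConfig, RunConfig, ScalingConfig
    from ray.train.torch import TorchTrainer


    def train_loop_per_worker(config):
        # Placeholder; the benchmark's real training loop lives in the
        # release test script.
        ...


    trainer = TorchTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
        # Tolerate up to 4 worker-group failures: 2 more than the number of
        # node kills the test expects, per the summary above.
        run_config=RunConfig(failure_config=FailureConfig(max_failures=4)),
    )
    trainer.fit()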