From 58159c5fd7c80c7dc43a62f9df714baf7d82eadb Mon Sep 17 00:00:00 2001
From: init-22
Date: Sun, 2 Feb 2025 23:17:14 +0530
Subject: [PATCH] adding mem_fraction 0.80 for jax workloads to resolve OOM of
 certain workloads

---
 docker/Dockerfile    |  1 -
 submission_runner.py | 13 +++++++++++--
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index 07375dd92..76bc5cfe0 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -92,7 +92,6 @@ RUN cd /algorithmic-efficiency && pip install -e '.[full]'
 
 RUN cd /algorithmic-efficiency && git fetch origin
 RUN cd /algorithmic-efficiency && git pull
-RUN pip install wandb # Todo: remove this, this is temporary for developing
 
 COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh
 
diff --git a/submission_runner.py b/submission_runner.py
index d2dcb03ac..2acc9d33c 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -693,12 +693,21 @@ def main(_):
 
   # Prevent OOM on librispeech conformer.
   base_workload = workloads.get_base_workload_name(FLAGS.workload)
-  if base_workload == 'librispeech_conformer':
-    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85'
+
+  if base_workload in [
+      'librispeech_conformer',
+      'librispeech_deepspeech',
+      'imagenet_vit',
+      'criteo1tb'
+  ]:
+    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80'
 
   if FLAGS.set_pytorch_max_split_size:
     os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
 
+  if FLAGS.framework == 'pytorch' and base_workload == 'librispeech_conformer':
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
   # Extend path according to framework.
   workload_metadata['workload_path'] = os.path.join(
       BASE_WORKLOADS_DIR,