Skip to content

Commit

Permalink
Adding mem_fraction 0.80 for JAX workloads to resolve OOM of certain…
Browse files Browse the repository at this point in the history
… workloads
  • Loading branch information
init-22 committed Feb 2, 2025
1 parent d7eebf8 commit 58159c5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
1 change: 0 additions & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ RUN cd /algorithmic-efficiency && pip install -e '.[full]'

RUN cd /algorithmic-efficiency && git fetch origin
RUN cd /algorithmic-efficiency && git pull
RUN pip install wandb

# TODO: remove this; it is temporary and only needed during development.
COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh
Expand Down
13 changes: 11 additions & 2 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,12 +693,21 @@ def main(_):

# Prevent OOM on memory-hungry workloads by lowering the GPU memory fraction
# handed to the XLA Python client (XLA_PYTHON_CLIENT_MEM_FRACTION).
base_workload = workloads.get_base_workload_name(FLAGS.workload)

# NOTE: this has to be a membership test. `base_workload` is compared against
# plain string literals elsewhere in this block, so it is presumably a str;
# comparing a str to a list with `==` is always False, which would mean the
# cap is never applied. 'librispeech_conformer' is part of the tuple, so it
# needs no separate branch.
if base_workload in (
    'librispeech_conformer',
    'librispeech_deepspeech',
    'imagenet_vit',
    'criteo1tb',
):
  os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80'

if FLAGS.set_pytorch_max_split_size:
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'

# This assignment replaces any 'max_split_size_mb' value set just above,
# because both branches write the same environment variable.
if FLAGS.framework == 'pytorch' and base_workload == 'librispeech_conformer':
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Extend path according to framework.
workload_metadata['workload_path'] = os.path.join(
BASE_WORKLOADS_DIR,
Expand Down

0 comments on commit 58159c5

Please sign in to comment.