clear cuda cache to help with memory leak/creep #1858

Merged: 2 commits merged into main on Aug 26, 2024
Conversation

winglian (Collaborator)

No description provided.

Comment on lines 1006 to 1007
torch.cuda.empty_cache()
gc.collect()
Contributor

Shouldn't gc.collect() be first?
In my experience, if there is a true leak (hanging references), empty_cache() cannot get rid of it on its own.

If it helps, I can try to debug where the leak is, given the training config.
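(For illustration only, not part of this PR: a minimal sketch of the ordering being suggested, with a hypothetical helper name. gc.collect() drops Python objects that are only kept alive by lingering references or cycles; only after that can torch.cuda.empty_cache() return the freed blocks to the driver.)

import gc
import torch

def free_unused_cuda_memory():
    # Drop Python-side references (including reference cycles) first,
    # so the tensors they keep alive are actually deallocated.
    gc.collect()
    # Then release the now-unused cached blocks held by PyTorch's
    # CUDA caching allocator back to the driver.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()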

winglian (Collaborator, Author)

Thanks. I reordered the calls and added some context from a reproduction.

winglian (Collaborator, Author)

@chiragjn these were about 40 and 60 steps in:

[Screenshot 2024-08-25 at 4.00.11 PM]
[Screenshot 2024-08-25 at 4.02.06 PM]

base_model: NousResearch/Meta-Llama-3.1-8B

plugins:
  - axolotl.integrations.liger.LigerPlugin
  - axolotl.integrations.spectrum.SpectrumPlugin

spectrum_top_fraction: 0.5
# Optional if using a pre-scanned model as your base_model. Useful if using a model mirror
spectrum_model_name: meta-llama/Meta-Llama-3.1-8B
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_cross_entropy: true
# liger_fused_linear_cross_entropy: true

strict: false


chat_template: llama3

rl: dpo
datasets:
  - path: argilla/distilabel-intel-orca-dpo-pairs
    split: train
    type: llama3.icr

dataset_prepared_path: last_run_prepared
dataset_processes: 1
val_set_size: 0.02
output_dir: ./outputs/out

sequence_len: 2048
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot_id|>
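(Side note, not from the PR: a config like this is normally launched through axolotl's CLI, e.g. accelerate launch -m axolotl.cli.train dpo-repro.yml, where the config filename here is hypothetical.)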

winglian (Collaborator, Author)

[Screenshot 2024-08-25 at 4.12.02 PM]

And this is with the gc/clear-cache call on each step.
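(Illustrative only, not part of the PR: one way to confirm the creep is gone is to log allocator stats at each step and watch whether they level off.)

import torch

def log_cuda_memory(step):
    # memory_allocated: bytes currently held by live tensors.
    # memory_reserved: bytes held by the CUDA caching allocator (allocated + cached).
    alloc_mib = torch.cuda.memory_allocated() / 2**20
    reserved_mib = torch.cuda.memory_reserved() / 2**20
    print(f"step {step}: allocated={alloc_mib:.1f} MiB, reserved={reserved_mib:.1f} MiB")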

winglian merged commit 17af1d7 into main on Aug 26, 2024. 7 checks passed.
winglian deleted the dpo-mem-leak branch on August 26, 2024 at 19:50.