From e8d47e521ad27c0fb6890d70ebedad9652a0df7a Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 22 Aug 2024 22:44:33 -0400
Subject: [PATCH 1/2] clear cuda cache to help with memory leak/creep

---
 src/axolotl/core/trainer_builder.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 1a073ca04..855974e59 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -4,6 +4,7 @@
 """
 
 import abc
+import gc
 import importlib
 import importlib.util
 import logging
@@ -15,11 +16,12 @@
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Dict, List, Literal, Optional, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Type, Union
 
 import torch
 import transformers
 from datasets import Dataset
+from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
@@ -997,6 +999,14 @@ def tokenize_row(
                 res[key] = res[key][1:]
         return res
 
+    def training_step(
+        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
+    ) -> torch.Tensor:
+        loss: torch.Tensor = super().training_step(model, inputs)
+        torch.cuda.empty_cache()
+        gc.collect()
+        return loss
+
 
 class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
     """

From 3b612b1a5a66d153b11500029f94a8a118662a0e Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 23 Aug 2024 15:44:02 -0400
Subject: [PATCH 2/2] reverse order of gc

---
 src/axolotl/core/trainer_builder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 855974e59..656ded255 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1003,8 +1003,8 @@ def training_step(
         self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
     ) -> torch.Tensor:
         loss: torch.Tensor = super().training_step(model, inputs)
-        torch.cuda.empty_cache()
         gc.collect()
+        torch.cuda.empty_cache()
         return loss
 
 