def training_step(
    self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
) -> torch.Tensor:
    """
    Run the parent class's training step, then trigger memory cleanup.

    Once the inherited ``training_step`` has returned, a Python garbage
    collection pass is forced and ``torch.cuda.empty_cache()`` is called
    before the loss is handed back to the caller.

    Args:
        model: module forwarded unchanged to the parent implementation.
        inputs: batch mapping forwarded unchanged to the parent
            implementation.

    Returns:
        The loss tensor produced by the parent ``training_step``.
    """
    step_loss: torch.Tensor = super().training_step(model, inputs)

    # Cleanup runs on every step, in this order: Python GC pass first,
    # then the CUDA cache flush.
    gc.collect()
    torch.cuda.empty_cache()

    return step_loss