Commit e72fe4c

[feature] add the ability to recover shuffled datasets from checkpoints (#247)

[feature] adds the ability to recover shuffled datasets from checkpoints.
[fix] fixed a bug where the epoch of a loaded checkpoint could exceed the one in the config.
LongzhuoWang authored Sep 14, 2024
1 parent e5cbb21 commit e72fe4c
Showing 2 changed files with 95 additions and 17 deletions.
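
For context on the mechanism: the recovery added in this commit leans on the shuffle-indices cache of the HuggingFace datasets library, where shuffle() reloads a stored permutation whenever its indices cache file already exists, instead of drawing a new one. Below is a minimal sketch of that behavior, not mlora's actual loading code; the toy data, paths, and the save/reload step are assumptions, while the "data_points" split key mirrors the diff below.

import os
from datasets import Dataset, DatasetDict, load_from_disk

os.makedirs(".cache", exist_ok=True)

# Hypothetical toy data, saved and reloaded so the dataset is disk-backed;
# the shuffle-indices cache is only consulted for disk-backed datasets.
DatasetDict(
    {"data_points": Dataset.from_dict({"text": [f"sample {i}" for i in range(8)]})}
).save_to_disk(".cache/demo_ds")
data = load_from_disk(".cache/demo_ds")

cache = {"data_points": ".cache/shuffle_demo"}

# First run: shuffle() draws a permutation and writes it to the cache file.
first = data.shuffle(indices_cache_file_names=cache)

# Recovery: the cache file still exists, so shuffle() reloads the stored
# permutation instead of reshuffling; the order matches the first run.
second = data.shuffle(indices_cache_file_names=cache)
assert first["data_points"]["text"] == second["data_points"]["text"]

This is why _shuffle_data() below copies the checkpointed shuffle_data file back into the cache path before calling shuffle(): the restored indices file makes the shuffle deterministic across restarts.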
66 changes: 65 additions & 1 deletion mlora/executor/task/task.py
@@ -1,4 +1,6 @@
import logging
import os
import shutil
from abc import abstractmethod
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple
@@ -30,6 +32,9 @@ class Task:
# need_terminal_
# the llm name, just for exporting the config file
llm_name_: str

recover_folder_: str | None
shuffle_data_cache_path_: str | None

def __init__(self, config: TaskConfig, llm_name: str) -> None:
self.config_ = config

@@ -41,6 +46,9 @@ def __init__(self, config: TaskConfig, llm_name: str) -> None:

self.llm_name_ = llm_name

self.recover_folder_ = None
self.shuffle_data_cache_path_ = None

@abstractmethod
def prepare(self, linears_info: OrderedDict[str, LinearInfo], tokenizer: Tokenizer):
# task prepare for execute
@@ -70,10 +78,54 @@ def notify_terminate(self):
def is_terminate(self) -> bool:
return self.terminate_

def _del_cache_file(self):
if self.shuffle_data_cache_path_ is None:
return
cache_path: str = self.shuffle_data_cache_path_
# If the cache file exists, delete it.
if os.path.exists(cache_path):
os.remove(cache_path)
# If the cache folder is then empty, delete it as well.
cache_dir, _ = os.path.split(cache_path)
if os.path.exists(cache_dir) and len(os.listdir(cache_dir)) == 0:
os.rmdir(cache_dir)

def _shuffle_data(self, data):
# If the data preprocess_type is shuffle, create a cache folder
# to store the shuffled data; it is also used when saving checkpoints.
data_name: str = ""
if self.config_.dataset_ is not None:
data_name = self.config_.dataset_.name_ + "_"
# warning: the cache path may be at most one directory deep,
# otherwise an error will result.
self.shuffle_data_cache_path_ = ".cache/shuffle_" + data_name + self.task_name()
# Clear the cache files before use.
self._del_cache_file()
cache_dir, _ = os.path.split(self.shuffle_data_cache_path_)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)

# If a checkpoint exists, copy its shuffle_data file to the cache path.
if self.recover_folder_ is not None:
recover_data_path: str = (
self.context_.path_
+ os.sep
+ self.recover_folder_
+ os.sep
+ "shuffle_data"
)
shutil.copy(recover_data_path, self.shuffle_data_cache_path_)
logging.info(
"Found shuffled data successfully, data status has been restored."
)
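# datasets' shuffle() reuses the indices in this cache file when it
# already exists, so a restored file reproduces the original order.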
return data.shuffle(
indices_cache_file_names={"data_points": self.shuffle_data_cache_path_}
)

def _pre_dataset(self):
preprocess_func: Dict[str, Callable] = {
"default": lambda data: data,
"shuffle": lambda data: data.shuffle(),
"shuffle": lambda data: self._shuffle_data(data),
"sort": lambda data: data.sort(),
}

@@ -94,6 +146,7 @@ def _pre_dataset(self):
if preprocess_type not in preprocess_func:
raise NotImplementedError

# Process data according to the data preprocess_type.
data = preprocess_func[preprocess_type](data)
logging.info(
f"Adapter {self.config_.adapter_.name_} "
@@ -125,6 +178,17 @@ def _expand_batch_tokens(

return ret_batch_tokens, ret_batch_masks

def _save_data(self, output_dir: str):
if self.config_.dataset_ is None or self.shuffle_data_cache_path_ is None:
return
preprocess_type: str = self.config_.dataset_.preprocess_
cache_path: str = self.shuffle_data_cache_path_
# If the data preprocess_type is shuffle, copy the shuffled
# data from the cache path into the checkpoint.
if preprocess_type == "shuffle":
shuffle_data_path = output_dir + os.sep + "shuffle_data"
shutil.copy(cache_path, shuffle_data_path)

def adapter_model(self) -> List[AdapterModel]:
return [self.context_.adapter_model()]

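Before the second file, a note on naming: checkpoints are written under directories of the form checkpoint_{step}_{epoch}_{data_idx} (built in _save below), and the renamed parser helper splits those fields back out when scanning for a recovery point. A tiny standalone illustration, with hypothetical values:

# Hypothetical values for illustration only.
now_step_, now_epoch_, now_data_idx_ = 120, 2, 384
checkpoint_folder = "checkpoint_" + "_".join(
    [str(now_step_), str(now_epoch_), str(now_data_idx_)]
)
print(checkpoint_folder)  # -> checkpoint_120_2_384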
46 changes: 30 additions & 16 deletions mlora/executor/task/train_task.py
@@ -15,7 +15,7 @@
from .task import Task


def _get_context_state_from_dir_name(dir_name: str) -> Tuple[int, int, int]:
def _get_context_state_from_folder_name(dir_name: str) -> Tuple[int, int, int]:
split_group = dir_name.split("_")

epoch = int(split_group[1])
@@ -44,11 +44,11 @@ def prepare(self, linears_info: OrderedDict[str, LinearInfo], tokenizer: Tokenizer):
self.tokenizer_ = tokenizer
# prepare the context and the dataset
# NOTE: how to recover the ordering of the dataset
self._pre_dataset()
self._pre_context(linears_info)
self._pre_recover_context()
self._pre_dataset()

def _get_recover_folder(self) -> str | None:
def _get_recover_dir(self) -> str | None:
if not os.path.isdir(self.context_.path_):
return None

@@ -67,31 +67,40 @@ def is_recover_dir(dir_name: str) -> bool:
return None

max_step = -1
to_recover_folder: str | None = None
max_epoch = -1
to_recover_dir: str | None = None
# Find the most suitable checkpoint as the recovery folder
for folder in recover_folders:
base_folder = os.path.basename(os.path.normpath(folder))
step, epoch, data_idx = _get_context_state_from_dir_name(base_folder)
if step is not None and step > max_step:
max_step = max(max_step, step)
step, epoch, data_idx = _get_context_state_from_folder_name(base_folder)
# skip checkpoints that do not meet the condition
if step is None or epoch > self.config_.num_epochs_:
continue
# Take the maximum step; when steps are equal, take the maximum epoch
if step > max_step or (step == max_step and epoch > max_epoch):
max_step = step
max_epoch = epoch
self.now_epoch_ = epoch
self.now_data_idx_ = data_idx
self.now_step_ = step
to_recover_folder = os.path.join(self.context_.path_, folder)

return to_recover_folder
# Record the recover folder name for restoring shuffle_data (if it exists).
self.recover_folder_ = folder
to_recover_dir = os.path.join(self.context_.path_, folder)
return to_recover_dir

def _pre_recover_context(self):
to_recover_folder = self._get_recover_folder()
if to_recover_folder is None:
to_recover_dir = self._get_recover_dir()
if to_recover_dir is None:
return

logging.info(
f"Task {self.task_name()} have recover directory {to_recover_folder}"
f"Task {self.task_name()} have recover directory {to_recover_dir}"
" need to recover."
)
self.checkpoint_ = True

# read the checkpoint file to recover the optimizer state
checkpoint = torch.load(to_recover_folder + os.sep + "checkpoint.bin")
checkpoint = torch.load(to_recover_dir + os.sep + "checkpoint.bin")

self.context_.recover_weight(checkpoint["weight_dict"])
self.context_.recover_optimizer(checkpoint["state_dict"])
@@ -176,14 +185,14 @@ def _expand_batch_tokens(
def _save(self, is_checkpoint: bool = False, additional_info: Dict[str, str] = {}):
output_dir = self.context_.path_
if is_checkpoint:
checkpoint_dir = "checkpoint_" + "_".join(
checkpoint_folder = "checkpoint_" + "_".join(
[
str(self.now_step_),
str(self.now_epoch_),
str(self.now_data_idx_),
]
)
output_dir = self.context_.path_ + os.sep + checkpoint_dir
output_dir = self.context_.path_ + os.sep + checkpoint_folder

if not os.path.exists(output_dir):
os.makedirs(output_dir)
@@ -197,6 +206,9 @@ def _save(self, is_checkpoint: bool = False, additional_info: Dict[str, str] = {
},
output_dir + os.sep + "checkpoint.bin",
)
# Save the shuffle_data cache into the checkpoint.
self._save_data(output_dir)

else:
torch.save(
self.context_.weight_dict(), output_dir + os.sep + "adapter_model.bin"
@@ -213,6 +225,8 @@ def _save(self, is_checkpoint: bool = False, additional_info: Dict[str, str] = {
@override
def done(self):
self._save(is_checkpoint=False)
# Delete the cache file.
self._del_cache_file()
# release the context
del self.context_

