diff --git a/.github/workflows/code-formatter.sh b/.github/workflows/code-formatter.sh
new file mode 100755
index 00000000..6e089585
--- /dev/null
+++ b/.github/workflows/code-formatter.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+black ./mlora
+black ./mlora_cli
+isort ./mlora --profile black
+isort ./mlora_cli --profile black
\ No newline at end of file
diff --git a/demo/checkpoint/checkpoint_case_1.yaml b/demo/checkpoint/checkpoint_case_1.yaml
new file mode 100644
index 00000000..c2710a30
--- /dev/null
+++ b/demo/checkpoint/checkpoint_case_1.yaml
@@ -0,0 +1,36 @@
+dispatcher:
+  name: "default"
+  concurrency_num: 1
+datasets:
+  - name: "data"
+    data: "demo/data.json"
+    prompt: "demo/prompt.yaml"
+    prompt_type: "instruction"
+    preprocess: "default"
+adapters:
+  - name: "lora_0"
+    type: "lora"
+    path: "adapters/lora_sft_checkpoint"
+    optimizer: "adamw"
+    lr: 3e-4
+    r: 32
+    alpha: 64
+    dropout: 0.05
+    target_modules:
+      q_proj: true
+      k_proj: true
+      v_proj: true
+      o_proj: true
+      gate_proj: false
+      down_proj: false
+      up_proj: false
+tasks:
+  - type: "train"
+    name: "task_0"
+    adapter: "lora_0"
+    dataset: "data"
+    batch_size: 16
+    mini_batch_size: 16
+    num_epochs: 2
+    cutoff_len: 256
+    save_step: 5
diff --git a/demo/checkpoint/checkpoint_case_2.yaml b/demo/checkpoint/checkpoint_case_2.yaml
new file mode 100644
index 00000000..39805de0
--- /dev/null
+++ b/demo/checkpoint/checkpoint_case_2.yaml
@@ -0,0 +1,36 @@
+dispatcher:
+  name: "default"
+  concurrency_num: 1
+datasets:
+  - name: "data"
+    data: "demo/data.json"
+    prompt: "demo/prompt.yaml"
+    prompt_type: "instruction"
+    preprocess: "default"
+adapters:
+  - name: "lora_0"
+    type: "lora"
+    path: "adapters/lora_sft_checkpoint"
+    optimizer: "adamw"
+    lr: 3e-4
+    r: 32
+    alpha: 64
+    dropout: 0.05
+    target_modules:
+      q_proj: true
+      k_proj: true
+      v_proj: true
+      o_proj: true
+      gate_proj: false
+      down_proj: false
+      up_proj: false
+tasks:
+  - type: "train"
+    name: "task_0"
+    adapter: "lora_0"
+    dataset: "data"
+    batch_size: 16
+    mini_batch_size: 16
+    num_epochs: 10
+    cutoff_len: 256
+    save_step: 10
diff --git a/mlora/config/lr_scheduler.py b/mlora/config/lr_scheduler.py
index c4c24659..56dba95d 100644
--- a/mlora/config/lr_scheduler.py
+++ b/mlora/config/lr_scheduler.py
@@ -14,7 +14,7 @@ def __init__(self, config: Dict[str, str]) -> None:
         self.init(self.__params_map, config)
 
     @abstractmethod
-    def to_fn_parameters(self) -> Dict[str, str]: ...
+    def to_fn_parameters(self) -> Dict[str, Any]: ...
 
 
 class CosineLRSchedulerConfig(LRSchedulerConfig):
diff --git a/mlora/executor/context/inference.py b/mlora/executor/context/inference.py
index 0f1eb715..67883a2f 100644
--- a/mlora/executor/context/inference.py
+++ b/mlora/executor/context/inference.py
@@ -24,7 +24,7 @@ def switch_device(self, device: str) -> None:
             return
 
         for _, adapter in self.adapter_model_.items():
-            self.switch_list_tensor(adapter.get_tensors(), device)
+            self.switch_list_tensor(adapter.get_all_tensors(), device)
 
         self.device_ = device
 
diff --git a/mlora/executor/context/lora.py b/mlora/executor/context/lora.py
index 5cfa0a5b..159a827f 100644
--- a/mlora/executor/context/lora.py
+++ b/mlora/executor/context/lora.py
@@ -1,5 +1,3 @@
-import logging
-import os
 from collections import OrderedDict
 from typing import Dict, override
 
@@ -14,8 +12,10 @@
 from .train import TrainTaskContext
 
 
-def _load_lora_weight(
-    obj: TaskContext, config: LoRAConfig, linears_info: OrderedDict[str, LinearInfo]
+def _init_lora_weight(
+    context: TaskContext,
+    config: LoRAConfig,
+    linears_info: OrderedDict[str, LinearInfo],
 ):
     # init the weight
     for linear_name, linear_info in linears_info.items():
@@ -25,7 +25,7 @@ def _load_lora_weight(
         if config.target_[target_name] is not True:
             continue
 
-        obj.adapter_model_[linear_name] = LoRA(
+        context.adapter_model_[linear_name] = LoRA(
             config.name_,
             linear_info.in_dim_,
             linear_info.out_dim_,
@@ -33,29 +33,8 @@
             config.alpha_,
             config.dropout_,
         )
-    weight_dict = None
-
-    if os.path.isdir(obj.path_):
-        logging.info(f"Adapter {obj.name_}:{obj.path_} weight exist, load from file.")
-        weight_dict = torch.load(f"{obj.path_}{os.sep}adapter_model.bin")
-        prefix_name = "base_model.model.model."
-    else:
-        logging.info(
-            f"Adapter {obj.name_}:{obj.path_} weight not exist, use the default weight."
-        )
-
-    for name, module in obj.adapter_model_.items():
-        lora_a = (
-            None
-            if weight_dict is None
-            else weight_dict[prefix_name + name + ".lora_A.weight"]
-        )
-        lora_b = (
-            None
-            if weight_dict is None
-            else weight_dict[prefix_name + name + ".lora_B.weight"]
-        )
-        module.init_weight(lora_a, lora_b)
+    for _, module in context.adapter_model_.items():
+        module.init_weight(None, None)
 
 
 class InferenceLoRAContext(InferenceTaskContext):
@@ -68,14 +47,16 @@ def __init__(
         self, config: LoRAConfig, linears_info: OrderedDict[str, LinearInfo]
     ) -> None:
         super().__init__(config, linears_info)
 
     @override
     def load_weight(self, linears_info: OrderedDict[str, LinearInfo]):
-        _load_lora_weight(self, self.config_, linears_info)
+        _init_lora_weight(self, self.config_, linears_info)
 
 
 class TrainLoRAContext(TrainTaskContext):
     config_: LoRAConfig
 
     def __init__(
-        self, config: LoRAConfig, linears_info: OrderedDict[str, LinearInfo]
+        self,
+        config: LoRAConfig,
+        linears_info: OrderedDict[str, LinearInfo],
     ) -> None:
         super().__init__(config, linears_info)
 
@@ -83,8 +64,9 @@ def __init__(
     @override
     def load_weight(self, linears_info: OrderedDict[str, LinearInfo]):
-        _load_lora_weight(self, self.config_, linears_info)
+        _init_lora_weight(self, self.config_, linears_info)
 
+    @override
     def weight_dict(self) -> Dict[str, torch.Tensor]:
         # base_model.model.model.layers.{0}.self_attn.{q_proj}.{lora_A}.weight
         # base_model.model.model.layers.{0}.mlp.{gate_proj}.{lora_A}.weight
@@ -95,3 +77,31 @@ def weight_dict(self) -> Dict[str, torch.Tensor]:
             ret_val[prefix_name + ".lora_B.weight"] = adapter.lora_b_
 
         return ret_val
+
+    @override
+    def state_dict(self) -> Dict[str, torch.Tensor]:
+        return self.optimizer_.state_dict()
+
+    @override
+    def recover_optimizer(self, state_dict: Dict[str, torch.Tensor]):
+        assert self.optimizer_ is not None
+        self.optimizer_.load_state_dict(state_dict)
+
+    @override
+    def recover_lr(self, last_epoch: int):
+        # last_epoch is incremented on every scheduler.step() call,
+        # which is different from the train epoch, so be careful
+        if self.lr_scheduler_ is None:
+            return
+
+        # recreate the lr scheduler from the recovered position
+        self.create_lr_scheduler(self.config_.lr_scheduler_config_, last_epoch)
+
+    @override
+    def recover_weight(self, weight_dict: Dict[str, torch.Tensor]):
+        assert weight_dict is not None
+        prefix_name = "base_model.model.model."
+        for name, module in self.adapter_model_.items():
+            lora_a = weight_dict[prefix_name + name + ".lora_A.weight"]
+            lora_b = weight_dict[prefix_name + name + ".lora_B.weight"]
+            module.init_weight(lora_a, lora_b)
diff --git a/mlora/executor/context/train.py b/mlora/executor/context/train.py
index c336c85f..9e437e9d 100644
--- a/mlora/executor/context/train.py
+++ b/mlora/executor/context/train.py
@@ -22,13 +22,14 @@ class TrainTaskContext(TaskContext):
     lr_scheduler_: torch.optim.lr_scheduler.LRScheduler | None
 
     def __init__(
-        self, config: AdapterConfig, linears_info: OrderedDict[str, LinearInfo]
+        self,
+        config: AdapterConfig,
+        linears_info: OrderedDict[str, LinearInfo],
     ) -> None:
         super().__init__(config)
 
         # load the adapter's weight
         self.load_weight(linears_info)
-
         for module in self.adapter_model_.values():
             module.enable_grad()
 
@@ -38,6 +39,19 @@ def __init__(
     @abstractmethod
     def weight_dict(self) -> Dict[str, torch.Tensor]: ...
 
+    @abstractmethod
+    def state_dict(self) -> Dict[str, torch.Tensor]: ...
+
+    # recover the optimizer / lr scheduler / weight from a checkpoint
+    @abstractmethod
+    def recover_optimizer(self, state_dict: Dict[str, torch.Tensor]): ...
+
+    @abstractmethod
+    def recover_lr(self, last_epoch: int): ...
+
+    @abstractmethod
+    def recover_weight(self, weight_dict: Dict[str, torch.Tensor]): ...
+
     def create_optimizer(self, optim_config: OptimizerConfig | None):
         assert optim_config is not None
 
@@ -46,13 +60,15 @@ def create_optimizer(self, optim_config: OptimizerConfig | None):
         parameters: List[torch.Tensor] = []
 
         for adapter in self.adapter_model_.values():
-            parameters.extend(adapter.get_tensors())
+            parameters.extend(adapter.get_trainable_tensors())
 
         self.optimizer_ = OPTIMIZER_CLASS[optimizer_type_](
             parameters, **optim_config.to_fn_parameters()
         )
 
-    def create_lr_scheduler(self, lr_scheduler_config: LRSchedulerConfig | None):
+    def create_lr_scheduler(
+        self, lr_scheduler_config: LRSchedulerConfig | None, last_epoch: int = -1
+    ):
         assert self.optimizer_ is not None
 
         if lr_scheduler_config is None:
@@ -60,9 +76,13 @@ def create_lr_scheduler(self, lr_scheduler_config: LRSchedulerConfig | None):
             return
 
         lr_scheduler_type_ = lr_scheduler_config.lr_scheduler_
-        assert lr_scheduler_type_ in LR_SCHEDULER_CLASS
+
+        kwargs = lr_scheduler_config.to_fn_parameters()
+        kwargs["last_epoch"] = last_epoch
+
         self.lr_scheduler_ = LR_SCHEDULER_CLASS[lr_scheduler_type_](
-            self.optimizer_, **lr_scheduler_config.to_fn_parameters()  # type: ignore
+            self.optimizer_,
+            **kwargs,  # type: ignore
         )
 
     def switch_device(self, device: str) -> None:
@@ -70,7 +90,7 @@ def switch_device(self, device: str) -> None:
             return
 
         for _, adapter in self.adapter_model_.items():
-            self.switch_list_tensor(adapter.get_tensors(), device)
+            self.switch_list_tensor(adapter.get_all_tensors(), device)
 
         self.switch_optimizer(device)
diff --git a/mlora/executor/task/train_task.py b/mlora/executor/task/train_task.py
index a69129e1..e62b90e7 100644
--- a/mlora/executor/task/train_task.py
+++ b/mlora/executor/task/train_task.py
@@ -14,6 +14,16 @@
 from .task import Task
 
 
+def _get_context_state_from_dir_name(dir_name: str) -> Tuple[int, int, int]:
+    # the checkpoint folder is named checkpoint_{step}_{epoch}_{data_idx}
+    split_group = dir_name.split("_")
+
+    step = int(split_group[1])
+    epoch = int(split_group[2])
+    data_idx = int(split_group[3])
+
+    return step, epoch, data_idx
+
+
 class TrainTask(Task):
     now_epoch_: int
 
@@ -32,8 +42,61 @@ def is_done(self) -> bool:
     def prepare(self, linears_info: OrderedDict[str, LinearInfo], tokenizer: Tokenizer):
         self.tokenizer_ = tokenizer
         # prepare the context and the dataset
+        # NOTE: recovering the order of the dataset is still an open question
        self._pre_dataset()
         self._pre_context(linears_info)
+        self._pre_recover_context()
+
+    def _get_recover_folder(self) -> str | None:
+        if not os.path.isdir(self.context_.path_):
+            return None
+
+        def is_recover_dir(dir_name: str) -> bool:
+            if "checkpoint" not in dir_name:
+                return False
+
+            if not os.path.isdir(os.path.join(self.context_.path_, dir_name)):
+                return False
+
+            return True
+
+        recover_folders = list(filter(is_recover_dir, os.listdir(self.context_.path_)))
+
+        if len(recover_folders) <= 0:
+            return None
+
+        max_step = -1
+        to_recover_folder: str | None = None
+        for folder in recover_folders:
+            base_folder = os.path.basename(os.path.normpath(folder))
+            step, epoch, data_idx = _get_context_state_from_dir_name(base_folder)
+            if step > max_step:
+                max_step = step
+                self.now_epoch_ = epoch
+                self.now_data_idx_ = data_idx
+                self.now_step_ = step
+                to_recover_folder = os.path.join(self.context_.path_, folder)
+
+        return to_recover_folder
+
+    def _pre_recover_context(self):
+        to_recover_folder = self._get_recover_folder()
+        if to_recover_folder is None:
+            return
+
+        logging.info(
+            f"Task {self.task_name()} has recover directory {to_recover_folder}, "
+            "need to recover."
+        )
+
+        # read the checkpoint file and recover the weight/optimizer/lr state
+        checkpoint = torch.load(to_recover_folder + os.sep + "checkpoint.bin")
+
+        self.context_.recover_weight(checkpoint["weight_dict"])
+        self.context_.recover_optimizer(checkpoint["state_dict"])
+        # recompute the lr scheduler's epoch for the recovery
+        lr_epoch = self.now_step_ // self.config_.accumulate_step_
+        self.context_.recover_lr(lr_epoch)
 
     @override
     def data(self, start_idx: int) -> Tuple[List[Tokens], List[MLoRADataConfig]]:
@@ -104,17 +167,34 @@ def _expand_batch_tokens(
 
         return ret_batch_tokens, ret_batch_masks
 
-    def _save(self, dir_suffix: str = "", additional_info: Dict[str, str] = {}):
+    def _save(self, is_checkpoint: bool = False, additional_info: Dict[str, str] = {}):
         output_dir = self.context_.path_
-        if dir_suffix != "":
-            output_dir += os.sep + self.context_.path_ + "_" + dir_suffix
+        if is_checkpoint:
+            checkpoint_dir = "checkpoint_" + "_".join(
+                [
+                    str(self.now_step_),
+                    str(self.now_epoch_),
+                    str(self.now_data_idx_),
+                ]
+            )
+            output_dir = self.context_.path_ + os.sep + checkpoint_dir
 
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)
 
-        torch.save(
-            self.context_.weight_dict(), output_dir + os.sep + "adapter_model.bin"
-        )
+        # save to disk; a checkpoint also needs the optimizer state dict
+        if is_checkpoint:
+            torch.save(
+                {
+                    "weight_dict": self.context_.weight_dict(),
+                    "state_dict": self.context_.state_dict(),
+                },
+                output_dir + os.sep + "checkpoint.bin",
+            )
+        else:
+            torch.save(
+                self.context_.weight_dict(), output_dir + os.sep + "adapter_model.bin"
+            )
 
         adapter_config: Dict[str, str] = {}
         adapter_config["base_model_name_or_path"] = self.llm_name_
@@ -126,7 +206,7 @@ def _save(self, dir_suffix: str = "", additional_info: Dict[str, str] = {}):
 
     @override
     def done(self):
-        self._save()
+        self._save(is_checkpoint=False)
         # release the context
         del self.context_
 
@@ -137,14 +217,14 @@ def terminate(self):
     @override
     def step(self):
         stepd: bool = False
+        need_checkpoint: bool = False
 
         if self.now_step_ % self.config_.accumulate_step_ == 0:
             stepd = True
             self.context_.step()
 
-        # to save the model
         if self.now_step_ % self.config_.save_step_ == 0:
-            self._save(f"{self.now_step_}")
+            need_checkpoint = True
 
         self.now_step_ += 1
         self.now_data_idx_ += self.config_.mini_batch_size_
@@ -153,6 +233,11 @@ def step(self):
             self.now_epoch_ += 1
             self.now_data_idx_ = 0
 
+        # save the checkpoint only after the counters are updated,
+        # because we need them to recover the state
+        if need_checkpoint:
+            self._save(is_checkpoint=True)
+
         # task finish we also need to step
         if not stepd and self.now_epoch_ >= self.config_.num_epochs_:
             self.context_.step()
diff --git a/mlora/model/modules/adapter.py b/mlora/model/modules/adapter.py
index 6f0e4b9e..ac55ce3d 100644
--- a/mlora/model/modules/adapter.py
+++ b/mlora/model/modules/adapter.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Dict, List
+from typing import List, MutableMapping
 
 import torch
 
@@ -15,18 +15,21 @@ def __init__(self, adapter_type: str, adapter_name: str):
         self.adapter_name_ = adapter_name
 
     @abstractmethod
-    def get_tensors(self) -> List[torch.Tensor]: ...
+    def get_trainable_tensors(self) -> List[torch.Tensor]: ...
+
+    @abstractmethod
+    def get_all_tensors(self) -> List[torch.Tensor]: ...
 
     def disable_grad(self):
-        for tensor in self.get_tensors():
+        for tensor in self.get_trainable_tensors():
             tensor.requires_grad_(False)
             assert tensor.requires_grad is False
 
     def enable_grad(self):
-        for tensor in self.get_tensors():
+        for tensor in self.get_trainable_tensors():
             tensor.requires_grad_(True)
             assert tensor.requires_grad is True
             assert tensor.is_leaf
 
 
-AdapterModel = Dict[str, Adapter]
+AdapterModel = MutableMapping[str, Adapter]
diff --git a/mlora/model/modules/lora.py b/mlora/model/modules/lora.py
index c78806b5..d8afead8 100644
--- a/mlora/model/modules/lora.py
+++ b/mlora/model/modules/lora.py
@@ -185,26 +185,21 @@ def __init__(
     def init_weight(
         self, lora_a: torch.Tensor | None = None, lora_b: torch.Tensor | None = None
     ):
-        if lora_a is None:
-            torch.nn.init.kaiming_normal_(self.lora_a_, a=math.sqrt(5))
-        else:
-            self.lora_a_ = (
-                lora_a.to("cpu")
-                .detach()
-                .clone()
-                .to(dtype=torch.float32)
-                .requires_grad_(True)
-            )
+        # Gradient calculations are temporarily disabled for the copy or init
+        with torch.no_grad():
+            if lora_a is None:
+                torch.nn.init.kaiming_normal_(self.lora_a_, a=math.sqrt(5))
+            else:
+                self.lora_a_.copy_(lora_a)
 
-        if lora_b is not None:
-            self.lora_b_ = (
-                lora_b.to("cpu")
-                .detach()
-                .clone()
-                .to(dtype=torch.float32)
-                .requires_grad_(True)
-            )
+            # lora_b starts as zero, so it only needs a copy when a weight is given
+            if lora_b is not None:
+                self.lora_b_.copy_(lora_b)
+
+    @override
+    def get_trainable_tensors(self) -> List[torch.Tensor]:
+        return [self.lora_a_, self.lora_b_]
 
     @override
-    def get_tensors(self) -> List[torch.Tensor]:
+    def get_all_tensors(self) -> List[torch.Tensor]:
         return [self.lora_a_, self.lora_b_]
diff --git a/tests/finetune_all_case.sh b/tests/finetune_all_case.sh
index caacf5de..06ad5356 100755
--- a/tests/finetune_all_case.sh
+++ b/tests/finetune_all_case.sh
@@ -2,11 +2,12 @@
 
 declare -a test_case_yamls=(
     "demo/lora/lora_case_1.yaml"
+    "demo/checkpoint/checkpoint_case_1.yaml"
+    "demo/checkpoint/checkpoint_case_2.yaml"
     "demo/loraplus/loraplus_case_1.yaml"
     "demo/dpo/dpo_case_1.yaml"
     "demo/dpo/dpo_case_2.yaml"
     "demo/dpo/dpo_case_3.yaml"
-    "demo/cpo/cpo_case_1.yaml"
 )
 
 set -x
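
For reference, a minimal sketch of how a checkpoint written by the new
TrainTask._save(is_checkpoint=True) path can be inspected offline. The folder
naming (checkpoint_{step}_{epoch}_{data_idx}) and the "weight_dict" /
"state_dict" keys follow the diff above; the concrete adapter path used here
is a hypothetical example, not a path this patch creates.

import torch

# hypothetical checkpoint folder saved at step 10, epoch 0, data_idx 160
path = "adapters/lora_sft_checkpoint/checkpoint_10_0_160/checkpoint.bin"
checkpoint = torch.load(path, map_location="cpu")

# "weight_dict" holds the LoRA tensors, keyed like
#   base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
for name, tensor in checkpoint["weight_dict"].items():
    print(name, tuple(tensor.shape))

# "state_dict" is the output of torch.optim.Optimizer.state_dict()
print(checkpoint["state_dict"]["param_groups"])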
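
A companion sketch of the recover_lr arithmetic: a torch scheduler's
last_epoch counts scheduler.step() calls (one per accumulated optimizer step),
not training epochs, which is why _pre_recover_context divides now_step_ by
accumulate_step_. The optimizer, the CosineAnnealingLR choice, and the numbers
below are illustrative assumptions, not values fixed by this patch.

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

now_step, accumulate_step = 10, 2
lr_epoch = now_step // accumulate_step  # scheduler.step() ran 5 times so far

# torch expects "initial_lr" in each param group when last_epoch != -1;
# loading the optimizer state_dict from the checkpoint normally restores it
for group in optimizer.param_groups:
    group.setdefault("initial_lr", group["lr"])

# recreate the scheduler at the recovered position, as create_lr_scheduler does
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=100, last_epoch=lr_epoch
)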