
fault tolerance and restore #233

Merged: 1 commit, Jul 5, 2024
.github/workflows/code-formatter.sh (6 additions, 0 deletions)
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+black ./mlora
+black ./mlora_cli
+isort ./mlora --profile black
+isort ./mlora_cli --profile black
demo/checkpoint/checkpoint_case_1.yaml (36 additions, 0 deletions)
@@ -0,0 +1,36 @@
+dispatcher:
+  name: "default"
+  concurrency_num: 1
+datasets:
+  - name: "data"
+    data: "demo/data.json"
+    prompt: "demo/prompt.yaml"
+    prompt_type: "instruction"
+    preprocess: "default"
+adapters:
+  - name: "lora_0"
+    type: "lora"
+    path: "adapters/lora_sft_checkpoint"
+    optimizer: "adamw"
+    lr: 3e-4
+    r: 32
+    alpha: 64
+    dropout: 0.05
+    target_modules:
+      q_proj: true
+      k_proj: true
+      v_proj: true
+      o_proj: true
+      gate_proj: false
+      down_proj: false
+      up_proj: false
+tasks:
+  - type: "train"
+    name: "task_0"
+    adapter: "lora_0"
+    dataset: "data"
+    batch_size: 16
+    mini_batch_size: 16
+    num_epochs: 2
+    cutoff_len: 256
+    save_step: 5
demo/checkpoint/checkpoint_case_2.yaml (36 additions, 0 deletions)
@@ -0,0 +1,36 @@
+dispatcher:
+  name: "default"
+  concurrency_num: 1
+datasets:
+  - name: "data"
+    data: "demo/data.json"
+    prompt: "demo/prompt.yaml"
+    prompt_type: "instruction"
+    preprocess: "default"
+adapters:
+  - name: "lora_0"
+    type: "lora"
+    path: "adapters/lora_sft_checkpoint"
+    optimizer: "adamw"
+    lr: 3e-4
+    r: 32
+    alpha: 64
+    dropout: 0.05
+    target_modules:
+      q_proj: true
+      k_proj: true
+      v_proj: true
+      o_proj: true
+      gate_proj: false
+      down_proj: false
+      up_proj: false
+tasks:
+  - type: "train"
+    name: "task_0"
+    adapter: "lora_0"
+    dataset: "data"
+    batch_size: 16
+    mini_batch_size: 16
+    num_epochs: 10
+    cutoff_len: 256
+    save_step: 10
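
The two demo configs are identical except for num_epochs and save_step, and both point the adapter at the same adapters/lora_sft_checkpoint path, so a second run can presumably pick up whatever checkpoint the first run left behind. A quick sanity check of that delta, as a minimal sketch assuming PyYAML is installed and the script is run from the repository root:

import yaml  # assumes PyYAML is available

with open("demo/checkpoint/checkpoint_case_1.yaml") as f:
    case_1 = yaml.safe_load(f)
with open("demo/checkpoint/checkpoint_case_2.yaml") as f:
    case_2 = yaml.safe_load(f)

# Everything except the training schedule should match.
task_1, task_2 = case_1["tasks"][0], case_2["tasks"][0]
delta = {k: (task_1[k], task_2[k]) for k in task_1 if task_1[k] != task_2[k]}
print(delta)  # expected: {'num_epochs': (2, 10), 'save_step': (5, 10)}
assert case_1["adapters"] == case_2["adapters"]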
mlora/config/lr_scheduler.py (1 addition, 1 deletion)
@@ -14,7 +14,7 @@ def __init__(self, config: Dict[str, str]) -> None:
         self.init(self.__params_map, config)

     @abstractmethod
-    def to_fn_parameters(self) -> Dict[str, str]: ...
+    def to_fn_parameters(self) -> Dict[str, Any]: ...


 class CosineLRSchedulerConfig(LRSchedulerConfig):
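
The widened return type matters because scheduler constructor arguments are not strings: for a cosine scheduler they include an integer step count and a float floor. A hypothetical config, for illustration only (the field names are assumptions, and Any must be imported from typing alongside Dict):

from typing import Any, Dict

class ExampleCosineConfig:
    # Hypothetical fields, not mLoRA's actual attribute names.
    t_max_: int = 1000
    eta_min_: float = 0.0

    def to_fn_parameters(self) -> Dict[str, Any]:
        # An int and a float, so the old Dict[str, str] annotation was too narrow.
        return {"T_max": self.t_max_, "eta_min": self.eta_min_}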
mlora/executor/context/inference.py (1 addition, 1 deletion)
@@ -24,7 +24,7 @@ def switch_device(self, device: str) -> None:
             return

         for _, adapter in self.adapter_model_.items():
-            self.switch_list_tensor(adapter.get_tensors(), device)
+            self.switch_list_tensor(adapter.get_all_tensors(), device)

         self.device_ = device

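
This rename reflects a split of the old get_tensors accessor into two: get_trainable_tensors feeds the optimizer (see create_optimizer in train.py below), while get_all_tensors is used when an adapter migrates between devices and frozen tensors must move along with trainable ones. A minimal sketch of the assumed shape of that interface; the real LoRA class is mLoRA's, and the details here are inferred from the diff, not copied from it:

from typing import List

import torch

class LoRASketch:
    """Illustration only: a LoRA pair plus the two accessors the diff relies on."""

    def __init__(self, in_dim: int, out_dim: int, r: int) -> None:
        self.lora_a_ = torch.zeros(r, in_dim, requires_grad=True)
        self.lora_b_ = torch.zeros(out_dim, r, requires_grad=True)

    def get_trainable_tensors(self) -> List[torch.Tensor]:
        # Only what the optimizer should update.
        return [self.lora_a_, self.lora_b_]

    def get_all_tensors(self) -> List[torch.Tensor]:
        # Everything that must follow the adapter across devices,
        # whether or not it is trainable.
        return [self.lora_a_, self.lora_b_]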
mlora/executor/context/lora.py (41 additions, 31 deletions)
@@ -1,5 +1,3 @@
-import logging
-import os
 from collections import OrderedDict
 from typing import Dict, override

@@ -14,8 +14,10 @@
 from .train import TrainTaskContext


-def _load_lora_weight(
-    obj: TaskContext, config: LoRAConfig, linears_info: OrderedDict[str, LinearInfo]
+def _init_lora_weight(
+    context: TaskContext,
+    config: LoRAConfig,
+    linears_info: OrderedDict[str, LinearInfo],
 ):
     # init the weight
     for linear_name, linear_info in linears_info.items():
@@ -25,37 +25,16 @@ def _load_lora_weight(
         if config.target_[target_name] is not True:
             continue

-        obj.adapter_model_[linear_name] = LoRA(
+        context.adapter_model_[linear_name] = LoRA(
             config.name_,
             linear_info.in_dim_,
             linear_info.out_dim_,
             config.r_,
             config.alpha_,
             config.dropout_,
         )
-    weight_dict = None
-
-    if os.path.isdir(obj.path_):
-        logging.info(f"Adapter {obj.name_}:{obj.path_} weight exist, load from file.")
-        weight_dict = torch.load(f"{obj.path_}{os.sep}adapter_model.bin")
-        prefix_name = "base_model.model.model."
-    else:
-        logging.info(
-            f"Adapter {obj.name_}:{obj.path_} weight not exist, use the default weight."
-        )
-
-    for name, module in obj.adapter_model_.items():
-        lora_a = (
-            None
-            if weight_dict is None
-            else weight_dict[prefix_name + name + ".lora_A.weight"]
-        )
-        lora_b = (
-            None
-            if weight_dict is None
-            else weight_dict[prefix_name + name + ".lora_B.weight"]
-        )
-        module.init_weight(lora_a, lora_b)
+    for _, module in context.adapter_model_.items():
+        module.init_weight(None, None)


 class InferenceLoRAContext(InferenceTaskContext):
@@ -68,23 +47,26 @@ def __init__(

     @override
     def load_weight(self, linears_info: OrderedDict[str, LinearInfo]):
-        _load_lora_weight(self, self.config_, linears_info)
+        _init_lora_weight(self, self.config_, linears_info)


 class TrainLoRAContext(TrainTaskContext):
     config_: LoRAConfig

     def __init__(
-        self, config: LoRAConfig, linears_info: OrderedDict[str, LinearInfo]
+        self,
+        config: LoRAConfig,
+        linears_info: OrderedDict[str, LinearInfo],
     ) -> None:
         super().__init__(config, linears_info)

         self.loss_fn_ = torch.nn.CrossEntropyLoss()

     @override
     def load_weight(self, linears_info: OrderedDict[str, LinearInfo]):
-        _load_lora_weight(self, self.config_, linears_info)
+        _init_lora_weight(self, self.config_, linears_info)

     @override
     def weight_dict(self) -> Dict[str, torch.Tensor]:
         # base_model.model.model.layers.{0}.self_attn.{q_proj}.{lora_A}.weight
         # base_model.model.model.layers.{0}.mlp.{gate_proj}.{lora_A}.weight
@@ -95,3 +77,31 @@ def weight_dict(self) -> Dict[str, torch.Tensor]:
             ret_val[prefix_name + ".lora_B.weight"] = adapter.lora_b_

         return ret_val
+
+    @override
+    def state_dict(self) -> Dict[str, torch.Tensor]:
+        return self.optimizer_.state_dict()
+
+    @override
+    def recover_optimizer(self, state_dict: Dict[str, torch.Tensor]):
+        assert self.optimizer_ is not None
+        self.optimizer_.load_state_dict(state_dict)
+
+    @override
+    def recover_lr(self, last_epoch: int):
+        # last_epoch is incremented on every .step() call of the scheduler;
+        # it is not the training epoch, so be careful
+        if self.lr_scheduler_ is None:
+            return

+        # recreate the lr scheduler at the saved position
+        self.create_lr_scheduler(self.config_.lr_scheduler_config_, last_epoch)
+
+    @override
+    def recover_weight(self, weight_dict: Dict[str, torch.Tensor]):
+        assert weight_dict is not None
+        prefix_name = "base_model.model.model."
+        for name, module in self.adapter_model_.items():
+            lora_a = weight_dict[prefix_name + name + ".lora_A.weight"]
+            lora_b = weight_dict[prefix_name + name + ".lora_B.weight"]
+            module.init_weight(lora_a, lora_b)
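
Taken together, weight_dict/state_dict on the save side and the three recover_* hooks on the restore side form a checkpoint round trip: weights first, then optimizer state, then the scheduler, which is rebuilt from config at the saved position. A sketch of how a caller might drive these hooks; the file layout and helper names here are assumptions for illustration, not the PR's actual checkpoint format:

import os

import torch

def save_checkpoint(context, path: str) -> None:
    # context is a TrainLoRAContext; hypothetical helper, not part of the PR.
    os.makedirs(path, exist_ok=True)
    torch.save(context.weight_dict(), os.path.join(path, "adapter_model.bin"))
    torch.save(context.state_dict(), os.path.join(path, "optimizer.bin"))

def restore_checkpoint(context, path: str, last_epoch: int) -> None:
    # Order matters: recover_optimizer expects the optimizer to already exist,
    # and recover_lr rebuilds the scheduler on top of the recovered optimizer.
    context.recover_weight(torch.load(os.path.join(path, "adapter_model.bin")))
    context.recover_optimizer(torch.load(os.path.join(path, "optimizer.bin")))
    context.recover_lr(last_epoch)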
mlora/executor/context/train.py (27 additions, 7 deletions)
@@ -22,13 +22,14 @@ class TrainTaskContext(TaskContext):
     lr_scheduler_: torch.optim.lr_scheduler.LRScheduler | None

     def __init__(
-        self, config: AdapterConfig, linears_info: OrderedDict[str, LinearInfo]
+        self,
+        config: AdapterConfig,
+        linears_info: OrderedDict[str, LinearInfo],
     ) -> None:
         super().__init__(config)

         # load the adapter's weight
         self.load_weight(linears_info)
-
         for module in self.adapter_model_.values():
             module.enable_grad()

@@ -38,6 +39,19 @@ def __init__(
     @abstractmethod
     def weight_dict(self) -> Dict[str, torch.Tensor]: ...

+    @abstractmethod
+    def state_dict(self) -> Dict[str, torch.Tensor]: ...
+
+    # recover_optimizer restores saved optimizer state
+    @abstractmethod
+    def recover_optimizer(self, state_dict: Dict[str, torch.Tensor]): ...
+
+    @abstractmethod
+    def recover_lr(self, now_epoch: int): ...
+
+    @abstractmethod
+    def recover_weight(self, weight_dict: Dict[str, torch.Tensor]): ...
+
     def create_optimizer(self, optim_config: OptimizerConfig | None):
         assert optim_config is not None

@@ -46,31 +60,37 @@ def create_optimizer(self, optim_config: OptimizerConfig | None):

         parameters: List[torch.Tensor] = []
         for adapter in self.adapter_model_.values():
-            parameters.extend(adapter.get_tensors())
+            parameters.extend(adapter.get_trainable_tensors())

         self.optimizer_ = OPTIMIZER_CLASS[optimizer_type_](
             parameters, **optim_config.to_fn_parameters()
         )

-    def create_lr_scheduler(self, lr_scheduler_config: LRSchedulerConfig | None):
+    def create_lr_scheduler(
+        self, lr_scheduler_config: LRSchedulerConfig | None, last_epoch: int = -1
+    ):
         assert self.optimizer_ is not None

         if lr_scheduler_config is None:
             self.lr_scheduler_ = None
             return

         lr_scheduler_type_ = lr_scheduler_config.lr_scheduler_
         assert lr_scheduler_type_ in LR_SCHEDULER_CLASS

+        kwargs = lr_scheduler_config.to_fn_parameters()
+        kwargs["last_epoch"] = last_epoch
+
         self.lr_scheduler_ = LR_SCHEDULER_CLASS[lr_scheduler_type_](
-            self.optimizer_, **lr_scheduler_config.to_fn_parameters()  # type: ignore
+            self.optimizer_,
+            **kwargs,  # type: ignore
         )

     def switch_device(self, device: str) -> None:
         if self.device_ == device:
             return

         for _, adapter in self.adapter_model_.items():
-            self.switch_list_tensor(adapter.get_tensors(), device)
+            self.switch_list_tensor(adapter.get_all_tensors(), device)

         self.switch_optimizer(device)
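
One PyTorch subtlety behind the new last_epoch parameter: a scheduler's last_epoch counts scheduler.step() calls, not training epochs (the comment in recover_lr warns about exactly this), and rebuilding a scheduler with last_epoch >= 0 only works if the optimizer's param groups already carry an "initial_lr" entry, which a previously attached scheduler writes in and an optimizer state_dict round trip preserves. A standalone sketch of that behavior, independent of mLoRA:

import torch

params = [torch.nn.Parameter(torch.zeros(2, 2))]

# On a fresh optimizer, resuming a scheduler at step 5 fails: no "initial_lr" yet.
opt = torch.optim.AdamW(params, lr=3e-4)
try:
    torch.optim.lr_scheduler.StepLR(opt, step_size=10, last_epoch=5)
except KeyError as err:
    print(err)  # param 'initial_lr' is not specified in param_groups[0] ...

# Attaching a scheduler once writes "initial_lr" into the param groups, and the
# optimizer state_dict carries it, so the restore path (load optimizer state,
# then recreate the scheduler at the saved step count) works as in recover_lr.
opt = torch.optim.AdamW(params, lr=3e-4)
torch.optim.lr_scheduler.StepLR(opt, step_size=10)
saved = opt.state_dict()

restored_opt = torch.optim.AdamW(params, lr=3e-4)
restored_opt.load_state_dict(saved)
resumed = torch.optim.lr_scheduler.StepLR(restored_opt, step_size=10, last_epoch=5)
print(resumed.get_last_lr())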