fault tolerance and restore
cainiaogoroad committed Jul 5, 2024
1 parent 73c889a commit 4910946
Showing 7 changed files with 34 additions and 14 deletions.
14 changes: 11 additions & 3 deletions mlora/config/lr_scheduler.py
@@ -14,7 +14,7 @@ def __init__(self, config: Dict[str, str]) -> None:
self.init(self.__params_map, config)

@abstractmethod
def to_fn_parameters(self) -> Dict[str, Any]: ...
    def to_fn_parameters(self, now_epoch: int | None = None) -> Dict[str, Any]: ...


class CosineLRSchedulerConfig(LRSchedulerConfig):
@@ -31,8 +31,16 @@ def __init__(self, config: Dict[str, str]) -> None:
self.eta_min_ = int(self.eta_min_)

@override
def to_fn_parameters(self) -> Dict[str, Any]:
return {"T_max": float(self.t_max_), "eta_min": float(self.eta_min_)}
    def to_fn_parameters(self, now_epoch: int | None = None) -> Dict[str, Any]:
ret_parameters: Dict[str, Any] = {
"T_max": float(self.t_max_),
"eta_min": float(self.eta_min_),
}

if now_epoch is not None:
ret_parameters["last_epoch"] = int(now_epoch)

return ret_parameters


LRSCHEDULERCONFIG_CLASS = {
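
The now_epoch parameter exists so that a restored run can rebuild its scheduler at the position where training stopped, since PyTorch's CosineAnnealingLR accepts a last_epoch argument. The sketch below shows how the returned dict could be consumed when resuming; the model, optimizer, and parameter values are placeholders, not the project's actual wiring.

import torch

# Placeholder model/optimizer for illustration only.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Roughly what to_fn_parameters(now_epoch=4) is expected to return (values arbitrary).
fn_parameters = {"T_max": 100.0, "eta_min": 0.0, "last_epoch": 4}

# Resuming with last_epoch >= 0 requires 'initial_lr' in each param group; it is
# normally restored with the optimizer state, set here so the sketch runs standalone.
for group in optimizer.param_groups:
    group.setdefault("initial_lr", group["lr"])

# The scheduler continues the cosine curve from epoch 4 instead of restarting at -1.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **fn_parameters)
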
2 changes: 1 addition & 1 deletion mlora/executor/context/train.py
@@ -1,6 +1,6 @@
from abc import abstractmethod
from collections import OrderedDict
from typing import Callable, Dict, List, Type
from typing import Callable, Dict, List, Optional, Type

import torch

2 changes: 1 addition & 1 deletion mlora/executor/task/task.py
@@ -91,7 +91,7 @@ def _pre_dataset(self):
data = preprocess_func[preprocess_type](data)
logging.info(
f"Adapter {self.config_.adapter_.name_} "
f"data size: {len(data["data_points"])}"
f"data size: {len(data['data_points'])}"
)

for _, data_point in tqdm(enumerate(data["data_points"])):
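
The quoting change above is needed because, before Python 3.12 (PEP 701), an f-string cannot reuse its own quote character inside a replacement field; a minimal reproduction:

data = {"data_points": [1, 2, 3]}
# print(f"size: {len(data["data_points"])}")   # SyntaxError on Python < 3.12
print(f"size: {len(data['data_points'])}")     # parses on all supported versions
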
3 changes: 1 addition & 2 deletions mlora/executor/task/train_task.py
@@ -26,13 +26,12 @@ def _get_context_state_from_dir_name(dir_name: str) -> Tuple[int, int, int]:

class TrainTask(Task):
now_epoch_: int

context_: TrainTaskContext
config_: TrainTaskConfig
recover_folder: str

def __init__(self, config: TaskConfig, llm_name: str) -> None:
super().__init__(config, llm_name)
self.now_epoch_ = 1

@override
def is_done(self) -> bool:
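
The removed self.now_epoch_ = 1 and the _get_context_state_from_dir_name helper imply that a restored TrainTask derives its starting position from the name of the recover folder rather than always starting at epoch 1. Below is a rough sketch of that idea with an invented folder-name layout (epoch / step / data index); the real naming scheme in mLoRA may differ.

import re
from typing import Tuple

def parse_checkpoint_dir(dir_name: str) -> Tuple[int, int, int]:
    # Hypothetical layout: "<adapter>_epoch_2_step_300_data_4500"
    match = re.search(r"epoch_(\d+)_step_(\d+)_data_(\d+)", dir_name)
    if match is None:
        raise ValueError(f"unrecognized checkpoint folder: {dir_name}")
    return int(match.group(1)), int(match.group(2)), int(match.group(3))

# A restored task would then resume from the parsed position instead of epoch 1.
epoch, step, data_idx = parse_checkpoint_dir("lora_sft_epoch_2_step_300_data_4500")
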
11 changes: 5 additions & 6 deletions mlora/model/llm/model_llama.py
@@ -123,10 +123,7 @@ def output_layer_forward():
}

module_name = self.name()
assert (
module_name in forward_func_dict
), f"error module name {
module_name}"
assert module_name in forward_func_dict, f"error module name {module_name}"

return forward_func_dict[module_name]()

@@ -238,8 +235,10 @@ def create_device_map() -> str | Dict[str, str]:
llama_model = AutoModelForCausalLM.from_pretrained(path, **additional_load_args)

if llama_model.config.model_type not in LlamaCompatibleModelTypes:
assert f"unsupported model type {
llama_model.config.model_type}, loading with llama compatible mode."
        logging.warning(
            f"unsupported model type {llama_model.config.model_type}, "
            "loading with llama compatible mode."
        )

logging.info(
f"loading llama compatible model - {llama_model.config.model_type}"
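
A bare assert on a non-empty f-string can never fail, so the model-type check above is really a warn-and-continue path: unknown model types are loaded in llama-compatible mode rather than rejected. A standalone sketch of that pattern; the set contents are placeholders, not the project's actual list.

import logging

# Placeholder values; the real set is defined in mlora/model/llm/model_llama.py.
LlamaCompatibleModelTypes = {"llama", "mistral", "qwen2"}

def check_model_type(model_type: str) -> None:
    if model_type not in LlamaCompatibleModelTypes:
        logging.warning(
            f"unsupported model type {model_type}, loading with llama compatible mode."
        )
    else:
        logging.info(f"loading llama compatible model - {model_type}")
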
14 changes: 14 additions & 0 deletions mlora/model/modules/lora.py
@@ -181,10 +181,12 @@ def __init__(
self.alpha_: int = alpha
self.dropout_: float = dropout
self.scaling_: float = alpha / r
self.out_dim: int = out_dim

def init_weight(
self, lora_a: torch.Tensor | None = None, lora_b: torch.Tensor | None = None
):
        if lora_a is None:
            torch.nn.init.kaiming_normal_(self.lora_a_, a=math.sqrt(5))
        else:
            # Gradient calculations are temporarily disabled for the copy
            with torch.no_grad():
                # In-place assignment from the provided tensor
                self.lora_a_.copy_(lora_a)

        # lora_b starts as zeros, so it only needs a copy when a tensor is provided
        if lora_b is not None:
            with torch.no_grad():
self.lora_b_.copy_(lora_b)

@override
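
Since init_weight now accepts optional lora_a / lora_b tensors and copies them under torch.no_grad(), restoring an adapter amounts to loading the saved tensors and handing them to the layer. A hedged sketch; the checkpoint file layout and key names are assumptions, not the project's actual format.

import torch

def restore_lora_weights(lora_layer, checkpoint_path: str) -> None:
    # Assumed layout: a dict with "lora_a" and "lora_b" tensors saved via torch.save.
    state = torch.load(checkpoint_path, map_location="cpu")
    # The copies run inside init_weight under torch.no_grad(), so restored
    # values do not enter the autograd graph.
    lora_layer.init_weight(lora_a=state["lora_a"], lora_b=state["lora_b"])
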
2 changes: 1 addition & 1 deletion mlora/server/task.py
@@ -69,7 +69,7 @@ async def post_task(request: Request):

task_conf = TASKCONFIG_CLASS[req["type"]](req, adapters, datasets)

logging.info(f"Create new task: {req["name"]} with adapter")
logging.info(f"Create new task: {req['name']} with adapter")

# set the task's state
req["state"] = "UNK"