diff --git a/aspen/model.py b/aspen/model.py index a89dd4af..60dfa454 100644 --- a/aspen/model.py +++ b/aspen/model.py @@ -1,6 +1,5 @@ -from aspen.modelargs import LlamaModelArgs, MultiLoraBatchData +from aspen import LlamaModelArgs, MultiLoraBatchData -import time import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -18,13 +17,13 @@ def precompute_rope_angle(dim: int, seq_len: int, device: str, theta: float = 10 emb = torch.outer(seq, angles).float() emb = einops.repeat(emb, "... n -> ... (n r)", r=2) # cos(angle), sin(angle) - return (emb.cos().to(torch.float16), emb.sin().to(torch.float16)) + return (emb.cos().to(torch.float32), emb.sin().to(torch.float32)) def precompute_mask(input: MultiLoraBatchData, n_head: int, device: str) -> torch.Tensor: mask = torch.full((len(input.prompts_), n_head, input.batch_seq_len_, input.batch_seq_len_), float("-inf")) - mask = torch.triu(mask, diagonal=1).to(torch.float16).cuda(device) + mask = torch.triu(mask, diagonal=1).to(torch.float32).cuda(device) for idx, _ in enumerate(input.prompts_): zero_len = input.tokens_len_without_pad_[idx] @@ -36,7 +35,7 @@ def precompute_mask(input: MultiLoraBatchData, n_head: int, device: str) -> torc mask[idx] += torch.tensor([float("-inf")] * inf_len + [0] * zero_len).expand( input.batch_seq_len_, input.batch_seq_len_).cuda(device) - return mask.to(torch.float16) + return mask.to(torch.float32) def rotate_half(x: torch.Tensor) -> torch.Tensor: @@ -61,7 +60,7 @@ def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, angle: Tuple[torch.Tens class RMSNorm(): def __init__(self, weight: torch.Tensor, eps: float = 1e-06): self.norm_eps_ = eps - self.weight_ = weight + self.weight_ = weight.to(torch.float32) def _norm(self, data: torch.Tensor) -> torch.Tensor: return data * torch.rsqrt(data.pow(2).mean(-1, keepdim=True) + self.norm_eps_) @@ -70,6 +69,32 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: return self._norm(data.float()).type_as(data) * self.weight_ +class Lora(): + def __init__(self, adapter_name: str): + self.adapter_name_: str = adapter_name + + self.lora_a_: torch.Tensor = None + self.lora_b_: torch.Tensor = None + + self.r_: int = 0 + self.alpha_: int = 0 + self.dropout_: float = 0.0 + self.scaling_: float = 0.0 + + def set_parameter(self, r: int, alpha: int, dropout: float): + self.r_ = r + self.alpha_ = alpha + self.dropout_ = dropout + self.scaling_ = alpha / r + + def forward(self, data: torch.Tensor) -> torch.Tensor: + data_ = F.dropout(data, self.dropout_) + data_ @= self.lora_a_.transpose(0, 1) + data_ @= self.lora_b_.transpose(0, 1) + data_ *= self.scaling_ + return data_ + + class Linear(): def __init__(self, weight: torch.Tensor): row, col = weight.shape @@ -80,31 +105,25 @@ def __init__(self, weight: torch.Tensor): self.use_adapter_: bool = False # adapter list self.adapter_names_: Set[str] = set() - # lora weight - self.lora_a_: Dict[str, torch.Tensor] = {} # r * dim - self.lora_b_: Dict[str, torch.Tensor] = {} # dim * r - # common paramas - self.lora_dropout_: Dict[str, float] = {} - self.r_: Dict[str, int] = {} - self.lora_alpha_: Dict[str, int] = {} - self.scaling_: Dict[str, float] = {} - - def update_layer(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float): - if len(self.adapter_names_) <= 0: + self.loras_: Dict[str, Lora] = {} + + def set_lora_layer_parameter(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float): + if len(self.adapter_names_) <= 0 or not self.use_adapter_: return - self.r_[adapter_name] = r - 
self.lora_alpha_[adapter_name] = lora_alpha - self.lora_dropout_[adapter_name] = lora_dropout - self.scaling_[adapter_name] = lora_alpha / r + self.loras_[adapter_name].set_parameter(r, lora_alpha, lora_dropout) + + def set_lora_layer_weight(self, adapter_name: str, lora_name: str, weight: torch.Tensor): + if adapter_name not in self.loras_: + self.loras_[adapter_name] = Lora(adapter_name) - def update_lora_weight(self, adapter_name: str, lora_name: str, weight: torch.Tensor): if lora_name == "lora_A": - self.lora_a_[adapter_name] = weight + self.loras_[adapter_name].lora_a_ = weight elif lora_name == "lora_B": - self.lora_b_[adapter_name] = weight + self.loras_[adapter_name].lora_b_ = weight else: raise (f"No lora_name {lora_name}") + self.adapter_names_.add(adapter_name) def forward(self, data: torch.Tensor, input_args: MultiLoraBatchData) -> torch.Tensor: @@ -123,12 +142,8 @@ def forward(self, data: torch.Tensor, input_args: MultiLoraBatchData) -> torch.T if adapter_name == "": continue - data_ = F.dropout(data[start_idx: end_idx], - self.lora_dropout_[adapter_name]) - data_ @= self.lora_a_[adapter_name].transpose(0, 1) - data_ @= self.lora_b_[adapter_name].transpose(0, 1) - data_ *= self.scaling_[adapter_name] - result[start_idx: end_idx] += data_ + result[start_idx: end_idx] += self.loras_[ + adapter_name].forward(data[start_idx:end_idx]) return result @@ -154,14 +169,12 @@ def __init__(self, layer_id: int, args: LlamaModelArgs): self.n_heads_ = args.n_heads_ self.head_dim_ = args.dim_ // args.n_heads_ - def update_lora_configure(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float): - self.wk_.update_layer(adapter_name, r, lora_alpha, lora_dropout) - self.wq_.update_layer(adapter_name, r, lora_alpha, lora_dropout) - self.wv_.update_layer(adapter_name, r, lora_alpha, lora_dropout) - self.wo_.update_layer(adapter_name, r, lora_alpha, lora_dropout) - self.w1_.update_layer(adapter_name, r, lora_alpha, lora_dropout) - self.w2_.update_layer(adapter_name, r, lora_alpha, lora_dropout) - self.w3_.update_layer(adapter_name, r, lora_alpha, lora_dropout) + def set_lora_parameter(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float): + linear_layer_list = [self.wk_, self.wq_, self.wv_, + self.wo_, self.w1_, self.w2_, self.w3_] + for linear_layer in linear_layer_list: + linear_layer.set_lora_layer_parameter( + adapter_name, r, lora_alpha, lora_dropout) # @torch.compile def forward(self, data: torch.Tensor, mask: torch.Tensor, rope_angle: Tuple[torch.Tensor, torch.Tensor], input_args: MultiLoraBatchData): @@ -229,8 +242,8 @@ def __init__(self, args: LlamaModelArgs): self.dim_ = args.dim_ def update_lora_configure(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float): - for layer in self.layers_: - layer.update_lora_configure( + for transformer_layer in self.layers_: + transformer_layer.set_lora_parameter( adapter_name, r, lora_alpha, lora_dropout) def forward(self, input: MultiLoraBatchData): @@ -256,19 +269,16 @@ def forward_for_checkpoint(*inputs): def get_train_paramas(self, config: Dict[str, str]) -> List[int]: train_paramas = [] - for layer in self.layers_: + for transformer_layer in self.layers_: for lora_config in config["lora"]: adapter_name = lora_config["name"] - if adapter_name in layer.wq_.lora_a_: - train_paramas.append(layer.wq_.lora_a_[adapter_name]) - train_paramas.append(layer.wq_.lora_b_[adapter_name]) - if adapter_name in layer.wk_.lora_a_: - train_paramas.append(layer.wk_.lora_a_[adapter_name]) - 
train_paramas.append(layer.wk_.lora_b_[adapter_name]) - if adapter_name in layer.wv_.lora_a_: - train_paramas.append(layer.wv_.lora_a_[adapter_name]) - train_paramas.append(layer.wv_.lora_b_[adapter_name]) - if adapter_name in layer.wo_.lora_a_: - train_paramas.append(layer.wo_.lora_a_[adapter_name]) - train_paramas.append(layer.wo_.lora_b_[adapter_name]) + lora_layer_list = [transformer_layer.wq_.loras_, transformer_layer.wk_.loras_, + transformer_layer.wv_.loras_, transformer_layer.wo_.loras_, + transformer_layer.w1_.loras_, transformer_layer.w2_.loras_, + transformer_layer.w3_.loras_] + + for lora_layer in lora_layer_list: + if adapter_name in lora_layer: + train_paramas.append(lora_layer[adapter_name].lora_a_) + train_paramas.append(lora_layer[adapter_name].lora_b_) return train_paramas diff --git a/aspen/modelloader.py b/aspen/modelloader.py index c85af0b2..5b691617 100644 --- a/aspen/modelloader.py +++ b/aspen/modelloader.py @@ -42,7 +42,7 @@ def load_llama_7b_weight(model: LlamaModel, llama_model_path: str, device: str): elif "norm.weight" in layer_name: model.norm_ = RMSNorm(w, model.norm_eps_) elif "output.weight" in layer_name: - model.output_ = w + model.output_ = w.to(torch.float32) else: print(f"Not use layer {layer_name}.", file=sys.stderr) @@ -86,123 +86,47 @@ def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str): elif "norm.weight" in layer_name: model.norm_ = RMSNorm(w, model.norm_eps_) elif "lm_head.weight" in layer_name: - model.output_ = w + model.output_ = w.to(torch.float32) else: print(f"Not use layer {layer_name}.", file=sys.stderr) -def load_alpaca_lora_7b_weight(model: LlamaModel, lora_model_path: str, adapter_name: str, device: str): - lora_weight = torch.load( - lora_model_path, map_location=torch.device(device)) - for layer_name in lora_weight: - w: torch.Tensor = lora_weight[layer_name].to(torch.float16) - w.requires_grad_(True) - - layer_name = layer_name[len("base_model.model.model.layers."):] - layer_id = int(layer_name[:layer_name.find(".")]) - lora_name = "" - if "lora_A" in layer_name: - lora_name = "lora_A" - elif "lora_B" in layer_name: - lora_name = "lora_B" - - if "q_proj" in layer_name: - model.layers_[layer_id].wq_.update_lora_weight( - adapter_name, lora_name, w) - model.layers_[layer_id].wq_.use_adapter_ = True - elif "k_proj" in layer_name: - model.layers_[layer_id].wk_.update_lora_weight( - adapter_name, lora_name, w) - model.layers_[layer_id].wk_.use_adapter_ = True - elif "v_proj" in layer_name: - model.layers_[layer_id].wv_.update_lora_weight( - adapter_name, lora_name, w) - model.layers_[layer_id].wv_.use_adapter_ = True - elif "o_proj" in layer_name: - model.layers_[layer_id].wo_.update_lora_weight( - adapter_name, lora_name, w) - model.layers_[layer_id].wo_.use_adapter_ = True - else: - print(f"Not user layer {layer_name}") - - def load_random_lora_7b_weight(model: LlamaModel, adapter_name: str, r: int, dim: int, target_module: str, device: str) -> None: norm_mean = 0 norm_std = 1e-3 - for layer in model.layers_: - if target_module["q_proj"] is True: - wq_lora_a_weight = torch.normal( - mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float16) - wq_lora_b_weight = torch.normal( - mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float16) - layer.wq_.update_lora_weight( - adapter_name, "lora_A", wq_lora_a_weight) - layer.wq_.update_lora_weight( - adapter_name, "lora_B", wq_lora_b_weight) - layer.wq_.use_adapter_ = True - - if 
target_module["k_proj"] is True: - wk_lora_a_weight = torch.normal( - mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float16) - wk_lora_b_weight = torch.normal( - mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float16) - layer.wk_.update_lora_weight( - adapter_name, "lora_A", wk_lora_a_weight) - layer.wk_.update_lora_weight( - adapter_name, "lora_B", wk_lora_b_weight) - layer.wk_.use_adapter_ = True - - if target_module["v_proj"] is True: - wv_lora_a_weight = torch.normal( - mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float16) - wv_lora_b_weight = torch.normal( - mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float16) - layer.wv_.update_lora_weight( - adapter_name, "lora_A", wv_lora_a_weight) - layer.wv_.update_lora_weight( - adapter_name, "lora_B", wv_lora_b_weight) - layer.wv_.use_adapter_ = True - - if target_module["o_proj"] is True: - wo_lora_a_weight = torch.normal( - mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float16) - wo_lora_b_weight = torch.normal( - mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float16) - layer.wo_.update_lora_weight( - adapter_name, "lora_A", wo_lora_a_weight) - layer.wo_.update_lora_weight( - adapter_name, "lora_B", wo_lora_b_weight) - layer.wo_.use_adapter_ = True + target_module_name_list = ["q_proj", "k_proj", "v_proj", "o_proj", "w1_proj", "w2_proj", "w3_proj"] + for transformer_layer in model.layers_: + target_layer_list = [transformer_layer.wq_, transformer_layer.wk_, + transformer_layer.wv_, transformer_layer.wo_, + transformer_layer.w1_, transformer_layer.w2_, + transformer_layer.w3_] + for idx, module_name in enumerate(target_module_name_list): + if module_name in target_module and target_module[module_name]: + lora_a_weight = torch.normal( + mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float32) + lora_b_weight = torch.normal( + mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float32) + target_layer_list[idx].set_lora_layer_weight( + adapter_name, "lora_A", lora_a_weight) + target_layer_list[idx].set_lora_layer_weight( + adapter_name, "lora_B", lora_b_weight) def save_lora_model(model: LlamaModel, path: str, lora_name: str): lora_weight_dict = {} - for idx, layer in enumerate(model.layers_): + for idx, transformer_layer in enumerate(model.layers_): layer_prefix_name = "base_model.model.model.layers." + \ str(idx) + "." + "self_attn." 
- if lora_name in layer.wq_.lora_a_: - lora_weight_dict[layer_prefix_name + - "q_proj.lora_A.weight"] = layer.wq_.lora_a_[lora_name] - if lora_name in layer.wq_.lora_b_: - lora_weight_dict[layer_prefix_name + - "q_proj.lora_B.weight"] = layer.wq_.lora_b_[lora_name] - if lora_name in layer.wk_.lora_a_: - lora_weight_dict[layer_prefix_name + - "k_proj.lora_A.weigth"] = layer.wk_.lora_a_[lora_name] - if lora_name in layer.wk_.lora_b_: - lora_weight_dict[layer_prefix_name + - "k_proj.lora_B.weight"] = layer.wk_.lora_b_[lora_name] - if lora_name in layer.wv_.lora_a_: - lora_weight_dict[layer_prefix_name + - "v_proj.lora_A.weight"] = layer.wv_.lora_a_[lora_name] - if lora_name in layer.wv_.lora_b_: - lora_weight_dict[layer_prefix_name + - "v_proj.lora_B.weight"] = layer.wv_.lora_b_[lora_name] - if lora_name in layer.wo_.lora_a_: - lora_weight_dict[layer_prefix_name + - "o_proj.lora_A.weight"] = layer.wo_.lora_a_[lora_name] - if lora_name in layer.wo_.lora_b_: - lora_weight_dict[layer_prefix_name + - "o_proj.lora_B.weight"] = layer.wo_.lora_b_[lora_name] + lora_layer_list = [transformer_layer.wq_, transformer_layer.wk_, + transformer_layer.wv_, transformer_layer.wo_, + transformer_layer.w1_, transformer_layer.w2_, + transformer_layer.w3_] + lora_layer_name_list = ["q_proj", "k_proj", "v_proj", "o_proj", "w1_proj", "w2_proj", "w3_proj"] + for idx, lora_layer in enumerate(lora_layer_list): + if lora_name in lora_layer.loras_: + lora_weight_dict[layer_prefix_name + + f"{lora_layer_name_list[idx]}.lora_A.weight"] = lora_layer.loras_[lora_name].lora_a_ + lora_weight_dict[layer_prefix_name + + f"{lora_layer_name_list[idx]}.lora_B.weight"] = lora_layer.loras_[lora_name].lora_b_ torch.save(lora_weight_dict, path) diff --git a/config/lora.json b/config/lora.json index db688301..c73b3811 100644 --- a/config/lora.json +++ b/config/lora.json @@ -1,10 +1,11 @@ { - "base_model": "/yezhengmao/modules/llama-7b/7B/consolidated.00.pth", - "token_model": "/yezhengmao/modules/llama-7b/tokenizer.model", + "base_model": "", + "token_model": "", "cutoff_len": 512, "group_by_length": false, "expand_right": true, "device": "cuda:1", + "save_step": 200, "lora": [ { "name": "lora_0", @@ -20,9 +21,9 @@ "v_proj": true, "o_proj": true }, - "data": "data/train_lora_a.json", - "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}\n\n", - "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n{output}\n\n" + "data": "", + "prompt_input": "", + "prompt_no_input": "" }, { "name": "lora_1", @@ -38,9 +39,9 @@ "v_proj": true, "o_proj": true }, - "data": "data/train_lora_b.json", - "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}\n\n", - "prompt_no_input": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n{output}\n\n" + "data": "", + "prompt_input": "", + "prompt_no_input": "" } ] } \ No newline at end of file diff --git a/mlora.py b/mlora.py index 9a33a769..486ef3f0 100644 --- a/mlora.py +++ b/mlora.py @@ -1,9 +1,10 @@ -import json -import torch from aspen import LlamaModel, Tokenizer, DataSet from aspen import LlamaModelArgs, MultiLoraBatchData from aspen import load_llama_7b_weight, load_random_lora_7b_weight from aspen import save_lora_model + +import json +import torch import torch.optim with open('config/lora.json', 'r', encoding='utf8') as fp: @@ -47,19 +48,15 @@ def init_lora_model(llama_model: LlamaModel): torch.cuda.empty_cache() - # optim begin - optimizer = torch.optim.SGD( - llama_model.get_train_paramas(config), lr=1e-3) - # optim end + optimizer = torch.optim.AdamW(llama_model.get_train_paramas(config)) - step = 0 - # torch.autograd.set_detect_anomaly(True) + step_cnt = 0 while not data_set.check_done(): optimizer.zero_grad() loss_fn = torch.nn.CrossEntropyLoss() input: MultiLoraBatchData = data_set.get_batch_data() - step += 1 + step_cnt += 1 output = llama_model.forward(input) labels = torch.tensor(input.batch_tokens_, @@ -84,10 +81,10 @@ def init_lora_model(llama_model: LlamaModel): total_loss.backward() optimizer.step() - if step % 200 == 0: + if step_cnt % config["save_step"] == 0: for lora_config in config["lora"]: save_lora_model( - llama_model, lora_config["output"] + f".chk{step}", lora_config["name"]) + llama_model, lora_config["output"] + f".bin{step_cnt}", lora_config["name"]) for lora_config in config["lora"]: save_lora_model(
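
The per-adapter update applied by the new Lora class reduces to x -> dropout(x) @ A^T @ B^T * (alpha / r). Below is a minimal, self-contained sketch of that computation; the tensor shapes, the example hyperparameters, and the assumption that the dense weight is stored as (out_features, in_features) are illustrative, not taken from the patch:

import torch
import torch.nn.functional as F

def lora_delta(x: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor,
               alpha: int, r: int, dropout: float) -> torch.Tensor:
    # x: (batch, seq, dim), lora_a: (r, dim), lora_b: (dim, r)
    out = F.dropout(x, dropout)
    out = out @ lora_a.transpose(0, 1)   # (batch, seq, r)
    out = out @ lora_b.transpose(0, 1)   # (batch, seq, dim)
    return out * (alpha / r)             # matches scaling_ = alpha / r

# illustrative usage against a frozen base projection:
# result = x @ weight.transpose(0, 1) + lora_delta(x, a, b, alpha=16, r=8, dropout=0.05)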
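
The multi-adapter routing in Linear.forward can be pictured with a standalone sketch: the frozen base projection is applied to the whole batch, and each adapter then adds its own LoRA delta to the contiguous slice of rows that belongs to it. The (adapter_name, start, end) layout below is a simplification standing in for MultiLoraBatchData, not its actual fields, and dropout is omitted for brevity:

import torch

def multi_lora_linear(x, base_weight, adapters, batch_layout):
    # x: (batch, seq, dim); base_weight: (out_dim, dim)
    # adapters: name -> (lora_a of shape (r, dim), lora_b of shape (out_dim, r), scaling)
    # batch_layout: list of (adapter_name, start_idx, end_idx) slices over the batch dim
    result = x @ base_weight.transpose(0, 1)
    for name, start, end in batch_layout:
        if name == "":  # rows without an adapter keep the base output only
            continue
        lora_a, lora_b, scaling = adapters[name]
        result[start:end] += (x[start:end] @ lora_a.T @ lora_b.T) * scaling
    return result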