From 92020168139a915458ed4cb63e9430d26e98f640 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E7=99=BB=E6=B7=B3?= Date: Mon, 28 Aug 2023 14:36:57 +0800 Subject: [PATCH 1/2] Upload original codes --- .gitignore | 8 ++ README.md | 6 +- aspen/__init__.py | 26 ++++ aspen/dataset.py | 146 +++++++++++++++++++++++ aspen/main.py | 1 - aspen/model.py | 274 +++++++++++++++++++++++++++++++++++++++++++ aspen/modelargs.py | 43 +++++++ aspen/modelloader.py | 208 ++++++++++++++++++++++++++++++++ aspen/tokenizer.py | 30 +++++ aspen/utils.py | 8 ++ config/lora.json | 46 ++++++++ data/demo.json | 1 - mlora.py | 94 +++++++++++++++ requirements.txt | 7 +- setup.py | 5 + 15 files changed, 894 insertions(+), 9 deletions(-) create mode 100644 aspen/__init__.py create mode 100644 aspen/dataset.py delete mode 100644 aspen/main.py create mode 100644 aspen/model.py create mode 100644 aspen/modelargs.py create mode 100644 aspen/modelloader.py create mode 100644 aspen/tokenizer.py create mode 100644 aspen/utils.py create mode 100644 config/lora.json delete mode 100644 data/demo.json create mode 100644 mlora.py create mode 100644 setup.py diff --git a/.gitignore b/.gitignore index 68bc17f9..f52bf300 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,11 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# IDEs +.vscode/ + +# ASPEN +__pycache__/ +*.egg-info/ +*.egg \ No newline at end of file diff --git a/README.md b/README.md index de90a2bf..190c7c47 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # ASPEN: Efficient Multi-LoRA Fine Tuning with Shared-Based Model +[![Python application](https://github.com/TUDB-Labs/multi-lora-fine-tune/actions/workflows/python-app.yml/badge.svg)](https://github.com/TUDB-Labs/multi-lora-fine-tune/actions/workflows/python-app.yml) This repository provides tools for fine-tuning large language models (LLMs) using the LoRA or QLoRA methods more efficiently. It provides the framework to support multiple LoRA/qLoRA models fine tunning at the same time. By reusing the shared frozen-based model, we provide the framework to reduce GPU Memory usage for multiple fine-tuning models. @@ -41,7 +42,7 @@ Submit a pull request with a detailed explanation of your changes. ```bibtex @article{multi-lora, title={Aspen: Efficient Finetuning of Multiple Lora and QLora}, - author={zhenmao ye*, dengchen li*, tingfen lan, zhaoyi liu, jie zuo, lei duan, mingjie tang}, + author={zhengmao ye*, dengchun li*, tingfeng lan, zhaoyi liu, jie zuo, lei duan, mingjie tang}, journal={arXiv preprint arXiv:xxxx}, year={2023} } @@ -49,6 +50,3 @@ Submit a pull request with a detailed explanation of your changes. 
## License This project is licensed under the Apache 2.0 License - see the LICENSE file for details - - - diff --git a/aspen/__init__.py b/aspen/__init__.py new file mode 100644 index 00000000..3c114ab7 --- /dev/null +++ b/aspen/__init__.py @@ -0,0 +1,26 @@ +from aspen.utils import convert_hf_to_pth +from aspen.tokenizer import Tokenizer +from aspen.model import LlamaModel, Linear, RMSNorm +from aspen.modelargs import TokenizerArgs, LlamaModelArgs, MultiLoraBatchData, LoraBatchDataConfig +from aspen.dataset import DataSet +from aspen.modelloader import load_llama_7b_weight, load_llama_tf_weight +from aspen.modelloader import load_alpaca_lora_7b_weight, load_random_lora_7b_weight +from aspen.modelloader import save_lora_model + +__all__ = [ + "Tokenizer", + "LlamaModel", + "Linear", + "RMSNorm", + "TokenizerArgs", + "LlamaModelArgs", + "MultiLoraBatchData", + "LoraBatchDataConfig", + "DataSet", + "convert_hf_to_pth", + "load_llama_7b_weight", + "load_llama_tf_weight", + "load_alpaca_lora_7b_weight", + "load_random_lora_7b_weight", + "save_lora_model" +] diff --git a/aspen/dataset.py b/aspen/dataset.py new file mode 100644 index 00000000..a0fd00ad --- /dev/null +++ b/aspen/dataset.py @@ -0,0 +1,146 @@ +import math +import json +import random +from aspen import MultiLoraBatchData, LoraBatchDataConfig, Tokenizer +from typing import Dict, List, Tuple + + +class DataSet(): + config_ = None + # Dict[lora_name, ] + lora_token_data_: Dict[str, List[Tuple[str, List[int]]]] = None + + lora_num_epochs_: Dict[str, int] = {} + lora_cnt_epochs_: Dict[str, int] = {} + lora_start_idx_: Dict[str, int] = {} + lora_batch_size_: Dict[str, int] = {} + + def __get_lora_text_data(self) -> Dict[str, List[str]]: + lora_text_data = {} + for lora_config in self.config_["lora"]: + lora_name = lora_config["name"] + data_path = lora_config["data"] + lora_text_data[lora_name] = [] + self.lora_cnt_epochs_[lora_name] = 0 + self.lora_start_idx_[lora_name] = 0 + self.lora_batch_size_[lora_name] = lora_config["batch_size"] + self.lora_num_epochs_[lora_name] = lora_config["num_epochs"] + + with open(data_path, 'r', encoding='utf8') as fp: + for raw_data in json.load(fp): + raw_data_input = raw_data["input"] + raw_data_output = raw_data["output"] + raw_data_instruction = raw_data["instruction"] + text_data = "" + if raw_data_input is None or len(raw_data_input) <= 1: + text_data = lora_config["prompt_no_input"].replace( + "{output}", raw_data_output).replace("{instruction}", raw_data_instruction) + else: + text_data = lora_config["prompt_input"].replace( + "{output}", raw_data_output).replace("{instruction}", raw_data_instruction).replace("{input}", raw_data_input) + lora_text_data[lora_name].append(text_data) + + return lora_text_data + + def __init__(self, config: Dict[str, str], tokenizer: Tokenizer): + self.config_ = config + self.tokenizer_: Tokenizer = tokenizer + + print("to load text data from file.") + lora_text_data = self.__get_lora_text_data() + print("load text data from file done.") + + # Dict[lora_name, ] + self.lora_token_data_: Dict[str, List[Tuple[str, List[int]]]] = {} + + print("to encode text data to tokens") + for lora_name in lora_text_data: + self.lora_token_data_[lora_name] = [] + + for idx, text in enumerate(lora_text_data[lora_name]): + tokens = tokenizer.encode(text, bos=True, eos=True) + if len(tokens) > config["cutoff_len"]: + tokens = tokens[:config["cutoff_len"]] + self.lora_token_data_[lora_name].append((text, tokens)) + if idx % 10000 == 0: + print( + f"encode text data: 
{idx}/{len(lora_text_data[lora_name])}") + # group by length + if self.config_["group_by_length"]: + self.lora_token_data_[lora_name].sort( + key=lambda x: len(x[1]), reverse=True) + else: + random.shuffle(self.lora_token_data_[lora_name]) + print("encode text data to tokens done.") + + def check_done(self) -> bool: + for lora_name in self.lora_token_data_: + if self.lora_cnt_epochs_[lora_name] < self.lora_num_epochs_[lora_name]: + return False + return True + + def get_batch_data(self) -> MultiLoraBatchData: + prompts_list: List[str] = [] + batch_tokens_list: List[List[int]] = [] + + prompts_batch_config_list: List[LoraBatchDataConfig] = [] + + tokens_without_pad_len_list: List[int] = [] + + max_token_len = 0 + + batch_start_idx = 0 + + for lora_name in self.lora_token_data_: + if self.lora_cnt_epochs_[lora_name] >= self.lora_num_epochs_[lora_name]: + continue + start_idx = self.lora_start_idx_[lora_name] + end_idx = start_idx + self.lora_batch_size_[lora_name] + prompt_and_tokens_list = self.lora_token_data_[ + lora_name][start_idx:end_idx] + + for pt in prompt_and_tokens_list: + prompt, token = pt + prompts_list.append(prompt) + batch_tokens_list.append(token.copy()) + + max_token_len = max(max_token_len, len(token)) + tokens_without_pad_len_list.append(len(token)) + + lora_config = LoraBatchDataConfig(adapter_name_=lora_name, batch_start_idx_=batch_start_idx, + batch_end_idx_=batch_start_idx + len(prompt_and_tokens_list)) + batch_start_idx += len(prompt_and_tokens_list) + prompts_batch_config_list.append(lora_config) + + self.lora_start_idx_[lora_name] += self.lora_batch_size_[lora_name] + if self.lora_start_idx_[lora_name] >= len(self.lora_token_data_[lora_name]): + self.lora_start_idx_[lora_name] = 0 + self.lora_cnt_epochs_[lora_name] += 1 + + print(f"{lora_name} train data:") + print( + f" epoch: {self.lora_cnt_epochs_[lora_name] + 1} / {self.lora_num_epochs_[lora_name]}") + print( + f" step : {self.lora_start_idx_[lora_name]} / {len(self.lora_token_data_[lora_name])}") + print( + f" : {(self.lora_cnt_epochs_[lora_name] * len(self.lora_token_data_[lora_name]) + self.lora_start_idx_[lora_name]) * 100 / (self.lora_num_epochs_[lora_name] * len(self.lora_token_data_[lora_name]))}%") + + # align batch data + max_token_len = math.ceil(max_token_len / 8) * 8 + + for tokens in batch_tokens_list: + while len(tokens) < max_token_len: + if self.config_["expand_right"]: + tokens.append(self.tokenizer_.pad_id_) + else: + tokens.insert(0, self.tokenizer_.pad_id_) + + print( + f"batch data size: {max_token_len} * {len(batch_tokens_list)}") + + return MultiLoraBatchData(prompts_=prompts_list, + lora_batch_data_config_=prompts_batch_config_list, + batch_seq_len_=max_token_len, + expand_right_=self.config_["expand_right"], + batch_tokens_=batch_tokens_list, + tokens_len_without_pad_=tokens_without_pad_len_list) diff --git a/aspen/main.py b/aspen/main.py deleted file mode 100644 index 8b137891..00000000 --- a/aspen/main.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/aspen/model.py b/aspen/model.py new file mode 100644 index 00000000..a89dd4af --- /dev/null +++ b/aspen/model.py @@ -0,0 +1,274 @@ +from aspen.modelargs import LlamaModelArgs, MultiLoraBatchData + +import time +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import einops +import xformers.ops +import xformers.ops.fmha.attn_bias +from typing import List, Dict, Set, Tuple +from bitsandbytes.nn import Linear8bitLt, Int8Params + + +def precompute_rope_angle(dim: int, seq_len: int, device: str, theta: float = 
10000.0) -> Tuple[torch.Tensor, torch.Tensor]: + angles = 1.0 / (theta ** (torch.arange(0, dim, 2).to(device) + [: (dim // 2)].to(torch.float) / dim)) + seq = torch.arange(seq_len, device=angles.device) + emb = torch.outer(seq, angles).float() + emb = einops.repeat(emb, "... n -> ... (n r)", r=2) + # cos(angle), sin(angle) + return (emb.cos().to(torch.float16), emb.sin().to(torch.float16)) + + +def precompute_mask(input: MultiLoraBatchData, n_head: int, device: str) -> torch.Tensor: + mask = torch.full((len(input.prompts_), n_head, + input.batch_seq_len_, input.batch_seq_len_), float("-inf")) + mask = torch.triu(mask, diagonal=1).to(torch.float16).cuda(device) + + for idx, _ in enumerate(input.prompts_): + zero_len = input.tokens_len_without_pad_[idx] + inf_len = input.batch_seq_len_ - zero_len + if input.expand_right_: + mask[idx] += torch.tensor([0] * zero_len + [float("-inf")] * inf_len).expand( + input.batch_seq_len_, input.batch_seq_len_).cuda(device) + else: + mask[idx] += torch.tensor([float("-inf")] * inf_len + [0] * zero_len).expand( + input.batch_seq_len_, input.batch_seq_len_).cuda(device) + + return mask.to(torch.float16) + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x = einops.rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return einops.rearrange(x, "... d r -> ... (d r)") + + +def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, angle: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + # data shape is: batch_size * max_seq_len * n_head * n_dim + _, max_seq_len, _, dim_head = xq.shape + + cos = angle[0][:max_seq_len].view(max_seq_len, 1, dim_head) + sin = angle[1][:max_seq_len].view(max_seq_len, 1, dim_head) + + xq = (xq * cos) + (rotate_half(xq) * sin) + xk = (xk * cos) + (rotate_half(xk) * sin) + return (xq, xk) + + +class RMSNorm(): + def __init__(self, weight: torch.Tensor, eps: float = 1e-06): + self.norm_eps_ = eps + self.weight_ = weight + + def _norm(self, data: torch.Tensor) -> torch.Tensor: + return data * torch.rsqrt(data.pow(2).mean(-1, keepdim=True) + self.norm_eps_) + + def forward(self, data: torch.Tensor) -> torch.Tensor: + return self._norm(data.float()).type_as(data) * self.weight_ + + +class Linear(): + def __init__(self, weight: torch.Tensor): + row, col = weight.shape + self.weight_ = Linear8bitLt( + input_features=col, output_features=row, bias=False, has_fp16_weights=False) + self.weight_.weight = Int8Params( + weight.data, requires_grad=False).cuda(weight.device) + self.use_adapter_: bool = False + # adapter list + self.adapter_names_: Set[str] = set() + # lora weight + self.lora_a_: Dict[str, torch.Tensor] = {} # r * dim + self.lora_b_: Dict[str, torch.Tensor] = {} # dim * r + # common paramas + self.lora_dropout_: Dict[str, float] = {} + self.r_: Dict[str, int] = {} + self.lora_alpha_: Dict[str, int] = {} + self.scaling_: Dict[str, float] = {} + + def update_layer(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float): + if len(self.adapter_names_) <= 0: + return + + self.r_[adapter_name] = r + self.lora_alpha_[adapter_name] = lora_alpha + self.lora_dropout_[adapter_name] = lora_dropout + self.scaling_[adapter_name] = lora_alpha / r + + def update_lora_weight(self, adapter_name: str, lora_name: str, weight: torch.Tensor): + if lora_name == "lora_A": + self.lora_a_[adapter_name] = weight + elif lora_name == "lora_B": + self.lora_b_[adapter_name] = weight + else: + raise (f"No lora_name {lora_name}") + 
self.adapter_names_.add(adapter_name) + + def forward(self, data: torch.Tensor, input_args: MultiLoraBatchData) -> torch.Tensor: + # data shape is: batch_size * max_seq_len * dim + # result = data @ self.weight_.transpose(0, 1) + result = self.weight_.forward(data) + + if not self.use_adapter_: + return result + + for lora_config in input_args.lora_batch_data_config_: + adapter_name = lora_config.adapter_name_ + start_idx = lora_config.batch_start_idx_ + end_idx = lora_config.batch_end_idx_ + + if adapter_name == "": + continue + + data_ = F.dropout(data[start_idx: end_idx], + self.lora_dropout_[adapter_name]) + data_ @= self.lora_a_[adapter_name].transpose(0, 1) + data_ @= self.lora_b_[adapter_name].transpose(0, 1) + data_ *= self.scaling_[adapter_name] + result[start_idx: end_idx] += data_ + + return result + + +class Transformer(): + def __init__(self, layer_id: int, args: LlamaModelArgs): + # attention + self.wq_: Linear = None # dim * dim + self.wk_: Linear = None # dim * dim + self.wv_: Linear = None # dim * dim + self.wo_: Linear = None # dim * dim + # feed forward + self.w1_: Linear = None # also gate FNN * dim + self.w2_: Linear = None # also down dim * FNN + self.w3_: Linear = None # also up FNN * dim + # for lora linear + # norm + self.attention_norm_: RMSNorm = None # dim + self.ffn_norm_: RMSNorm = None # dim + # other arg + self.layer_id_ = layer_id + self.norm_eps_ = args.norm_eps_ + self.n_heads_ = args.n_heads_ + self.head_dim_ = args.dim_ // args.n_heads_ + + def update_lora_configure(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float): + self.wk_.update_layer(adapter_name, r, lora_alpha, lora_dropout) + self.wq_.update_layer(adapter_name, r, lora_alpha, lora_dropout) + self.wv_.update_layer(adapter_name, r, lora_alpha, lora_dropout) + self.wo_.update_layer(adapter_name, r, lora_alpha, lora_dropout) + self.w1_.update_layer(adapter_name, r, lora_alpha, lora_dropout) + self.w2_.update_layer(adapter_name, r, lora_alpha, lora_dropout) + self.w3_.update_layer(adapter_name, r, lora_alpha, lora_dropout) + + # @torch.compile + def forward(self, data: torch.Tensor, mask: torch.Tensor, rope_angle: Tuple[torch.Tensor, torch.Tensor], input_args: MultiLoraBatchData): + batch_size, max_seq_len, _ = data.shape + + attention_norm_data = self.attention_norm_.forward(data) + + xq = self.wq_.forward(attention_norm_data, input_args) + xk = self.wk_.forward(attention_norm_data, input_args) + xv = self.wv_.forward(attention_norm_data, input_args) + + # conver shape to multi head + xq = xq.view(batch_size, max_seq_len, self.n_heads_, self.head_dim_) + xk = xk.view(batch_size, max_seq_len, self.n_heads_, self.head_dim_) + xv = xv.view(batch_size, max_seq_len, self.n_heads_, self.head_dim_) + + # apply rotary embedding + xq, xk = apply_rotary_emb(xq, xk, rope_angle) + + # score shape is: batch_size * n_head * seq_len * dim_head + # convert shape to: batch_size * seq_len * dim + # attention_score = attention_score.transpose( + # 1, 2).contiguous().view(batch_size, max_seq_len, -1) + # attention_score = flash_attn_func(xq, xk, xv, causal=True) + # attention_score = attention_score.view(batch_size, max_seq_len, -1) + attention_score = xformers.ops.memory_efficient_attention( + xq, xk, xv, mask) + attention_score = attention_score.view(batch_size, max_seq_len, -1) + + # get output attention score + data = data + self.wo_.forward(attention_score, input_args) + + # feed forward fully connected + score_norm_data = self.ffn_norm_.forward(data) + w1 = self.w1_.forward(score_norm_data, 
input_args) + w3 = self.w3_.forward(score_norm_data, input_args) + + data = data + self.w2_.forward(F.silu(w1) * w3, input_args) + + return data + + +class LlamaModel(): + def __init__(self, args: LlamaModelArgs): + # weight + self.token_embedding_: torch.Tensor = None + + self.layers_: List[Transformer] = [] + for layer_id in range(args.n_layers_): + self.layers_.append(Transformer(layer_id, args)) + + self.norm_: RMSNorm = None # dim + self.output_: torch.Tensor = None # vocab size * dim + + # cos and sin + self.rope_angle_: Tuple[torch.Tensor, torch.Tensor] = precompute_rope_angle( + args.dim_ // args.n_heads_, args.max_seq_len_, args.device) + + self.norm_eps_ = args.norm_eps_ + + self.device_ = args.device + self.n_heads_ = args.n_heads_ + self.vocab_size_ = args.vocab_size_ + self.pad_id_ = args.pad_id_ + self.dim_ = args.dim_ + + def update_lora_configure(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float): + for layer in self.layers_: + layer.update_lora_configure( + adapter_name, r, lora_alpha, lora_dropout) + + def forward(self, input: MultiLoraBatchData): + tokens = torch.tensor(input.batch_tokens_, + dtype=torch.int).to(self.device_) + data = F.embedding(tokens, self.token_embedding_, + padding_idx=self.pad_id_).requires_grad_(True) + mask = precompute_mask(input, self.n_heads_, self.device_) + + def create_forward_for_checkpoint(module: Transformer): + def forward_for_checkpoint(*inputs): + return module.forward(*inputs) + return forward_for_checkpoint + + for layer in self.layers_: + data = torch.utils.checkpoint.checkpoint( + create_forward_for_checkpoint(layer), data, mask, self.rope_angle_, input) + + data = self.norm_.forward(data) + data @= self.output_.transpose(0, 1) + + return data + + def get_train_paramas(self, config: Dict[str, str]) -> List[int]: + train_paramas = [] + for layer in self.layers_: + for lora_config in config["lora"]: + adapter_name = lora_config["name"] + if adapter_name in layer.wq_.lora_a_: + train_paramas.append(layer.wq_.lora_a_[adapter_name]) + train_paramas.append(layer.wq_.lora_b_[adapter_name]) + if adapter_name in layer.wk_.lora_a_: + train_paramas.append(layer.wk_.lora_a_[adapter_name]) + train_paramas.append(layer.wk_.lora_b_[adapter_name]) + if adapter_name in layer.wv_.lora_a_: + train_paramas.append(layer.wv_.lora_a_[adapter_name]) + train_paramas.append(layer.wv_.lora_b_[adapter_name]) + if adapter_name in layer.wo_.lora_a_: + train_paramas.append(layer.wo_.lora_a_[adapter_name]) + train_paramas.append(layer.wo_.lora_b_[adapter_name]) + return train_paramas diff --git a/aspen/modelargs.py b/aspen/modelargs.py new file mode 100644 index 00000000..7e76ea5b --- /dev/null +++ b/aspen/modelargs.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from typing import List, Dict + + +@dataclass +class TokenizerArgs: + vocab_size_: int = -1 + bos_id_: int = -1 + eos_id_: int = -1 + pad_id_: int = -1 + + +@dataclass +class LlamaModelArgs: + dim_: int = 4096 + multiple_of_: int = 256 + n_heads_: int = 32 + n_layers_: int = 32 + norm_eps_: float = 1e-06 + vocab_size_: int = -1 + pad_id_: int = -1 + max_seq_len_: int = 2048 + device: str = "" + + +@dataclass +class LoraBatchDataConfig: + adapter_name_: str = "" + batch_start_idx_: int = -1 + batch_end_idx_: int = -1 + + +@dataclass +class MultiLoraBatchData: + prompts_: List[str] = None + lora_batch_data_config_: List[LoraBatchDataConfig] = None + + # batch seq len + batch_seq_len_: int = None + expand_right_: int = True + + batch_tokens_: List[List[int]] = None + 
tokens_len_without_pad_: List[int] = None diff --git a/aspen/modelloader.py b/aspen/modelloader.py new file mode 100644 index 00000000..c85af0b2 --- /dev/null +++ b/aspen/modelloader.py @@ -0,0 +1,208 @@ +import sys +import torch + +from aspen import LlamaModel, Linear, RMSNorm +from transformers import LlamaForCausalLM + + +def load_llama_7b_weight(model: LlamaModel, llama_model_path: str, device: str): + weight = torch.load(llama_model_path, map_location=torch.device(device)) + + for layer_name in weight: + w: torch.Tensor = weight[layer_name] + w.requires_grad_(False) + + if "layers" in layer_name: + layer_name = layer_name[len("layers."):] + layer_id = int(layer_name[:layer_name.find(".")]) + if "wq" in layer_name: + model.layers_[layer_id].wq_ = Linear(w) + elif "wk" in layer_name: + model.layers_[layer_id].wk_ = Linear(w) + elif "wv" in layer_name: + model.layers_[layer_id].wv_ = Linear(w) + elif "wo" in layer_name: + model.layers_[layer_id].wo_ = Linear(w) + elif "w1" in layer_name: + model.layers_[layer_id].w1_ = Linear(w) + elif "w2" in layer_name: + model.layers_[layer_id].w2_ = Linear(w) + elif "w3" in layer_name: + model.layers_[layer_id].w3_ = Linear(w) + elif "attention_norm" in layer_name: + model.layers_[layer_id].attention_norm_ = RMSNorm( + w, model.norm_eps_) + elif "ffn_norm" in layer_name: + model.layers_[layer_id].ffn_norm_ = RMSNorm( + w, model.norm_eps_) + else: + print(f"Not use layer {layer_name}.", file=sys.stderr) + elif "tok_embeddings" in layer_name: + model.token_embedding_ = w + elif "norm.weight" in layer_name: + model.norm_ = RMSNorm(w, model.norm_eps_) + elif "output.weight" in layer_name: + model.output_ = w + else: + print(f"Not use layer {layer_name}.", file=sys.stderr) + + +def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str): + weight = LlamaForCausalLM.from_pretrained( + llama_model_path, device_map=dev).state_dict() + + for layer_name in weight: + w: torch.Tensor = weight[layer_name] + w.requires_grad_(False) + + if "model.layers" in layer_name: + layer_name = layer_name[len("model.layers."):] + layer_id = int(layer_name[:layer_name.find(".")]) + if "self_attn.q_proj" in layer_name: + model.layers_[layer_id].wq_ = Linear(w) + elif "self_attn.k_proj" in layer_name: + model.layers_[layer_id].wk_ = Linear(w) + elif "self_attn.v_proj" in layer_name: + model.layers_[layer_id].wv_ = Linear(w) + elif "self_attn.o_proj" in layer_name: + model.layers_[layer_id].wo_ = Linear(w) + elif "mlp.gate_proj" in layer_name: + model.layers_[layer_id].w1_ = Linear(w) + elif "mlp.down_proj" in layer_name: + model.layers_[layer_id].w2_ = Linear(w) + elif "mlp.up_proj" in layer_name: + model.layers_[layer_id].w3_ = Linear(w) + elif "input_layernorm" in layer_name: + model.layers_[layer_id].attention_norm_ = RMSNorm( + w, model.norm_eps_) + elif "post_attention_layernorm" in layer_name: + model.layers_[layer_id].ffn_norm_ = RMSNorm( + w, model.norm_eps_) + else: + print( + f"Not use layer model.layers.{layer_name}.", file=sys.stderr) + elif "embed_tokens" in layer_name: + model.token_embedding_ = w + elif "norm.weight" in layer_name: + model.norm_ = RMSNorm(w, model.norm_eps_) + elif "lm_head.weight" in layer_name: + model.output_ = w + else: + print(f"Not use layer {layer_name}.", file=sys.stderr) + + +def load_alpaca_lora_7b_weight(model: LlamaModel, lora_model_path: str, adapter_name: str, device: str): + lora_weight = torch.load( + lora_model_path, map_location=torch.device(device)) + for layer_name in lora_weight: + w: torch.Tensor = 
lora_weight[layer_name].to(torch.float16) + w.requires_grad_(True) + + layer_name = layer_name[len("base_model.model.model.layers."):] + layer_id = int(layer_name[:layer_name.find(".")]) + lora_name = "" + if "lora_A" in layer_name: + lora_name = "lora_A" + elif "lora_B" in layer_name: + lora_name = "lora_B" + + if "q_proj" in layer_name: + model.layers_[layer_id].wq_.update_lora_weight( + adapter_name, lora_name, w) + model.layers_[layer_id].wq_.use_adapter_ = True + elif "k_proj" in layer_name: + model.layers_[layer_id].wk_.update_lora_weight( + adapter_name, lora_name, w) + model.layers_[layer_id].wk_.use_adapter_ = True + elif "v_proj" in layer_name: + model.layers_[layer_id].wv_.update_lora_weight( + adapter_name, lora_name, w) + model.layers_[layer_id].wv_.use_adapter_ = True + elif "o_proj" in layer_name: + model.layers_[layer_id].wo_.update_lora_weight( + adapter_name, lora_name, w) + model.layers_[layer_id].wo_.use_adapter_ = True + else: + print(f"Not user layer {layer_name}") + + +def load_random_lora_7b_weight(model: LlamaModel, adapter_name: str, r: int, dim: int, target_module: str, device: str) -> None: + norm_mean = 0 + norm_std = 1e-3 + for layer in model.layers_: + if target_module["q_proj"] is True: + wq_lora_a_weight = torch.normal( + mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float16) + wq_lora_b_weight = torch.normal( + mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float16) + layer.wq_.update_lora_weight( + adapter_name, "lora_A", wq_lora_a_weight) + layer.wq_.update_lora_weight( + adapter_name, "lora_B", wq_lora_b_weight) + layer.wq_.use_adapter_ = True + + if target_module["k_proj"] is True: + wk_lora_a_weight = torch.normal( + mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float16) + wk_lora_b_weight = torch.normal( + mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float16) + layer.wk_.update_lora_weight( + adapter_name, "lora_A", wk_lora_a_weight) + layer.wk_.update_lora_weight( + adapter_name, "lora_B", wk_lora_b_weight) + layer.wk_.use_adapter_ = True + + if target_module["v_proj"] is True: + wv_lora_a_weight = torch.normal( + mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float16) + wv_lora_b_weight = torch.normal( + mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float16) + layer.wv_.update_lora_weight( + adapter_name, "lora_A", wv_lora_a_weight) + layer.wv_.update_lora_weight( + adapter_name, "lora_B", wv_lora_b_weight) + layer.wv_.use_adapter_ = True + + if target_module["o_proj"] is True: + wo_lora_a_weight = torch.normal( + mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float16) + wo_lora_b_weight = torch.normal( + mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float16) + layer.wo_.update_lora_weight( + adapter_name, "lora_A", wo_lora_a_weight) + layer.wo_.update_lora_weight( + adapter_name, "lora_B", wo_lora_b_weight) + layer.wo_.use_adapter_ = True + + +def save_lora_model(model: LlamaModel, path: str, lora_name: str): + lora_weight_dict = {} + for idx, layer in enumerate(model.layers_): + layer_prefix_name = "base_model.model.model.layers." + \ + str(idx) + "." + "self_attn." 
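        # Note: the keys assembled below use the base_model.model.model.layers.{idx}.self_attn.
        # prefix, i.e. the same alpaca-lora / PEFT-style layout that
        # load_alpaca_lora_7b_weight above strips when reloading a saved adapter.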
+ if lora_name in layer.wq_.lora_a_: + lora_weight_dict[layer_prefix_name + + "q_proj.lora_A.weight"] = layer.wq_.lora_a_[lora_name] + if lora_name in layer.wq_.lora_b_: + lora_weight_dict[layer_prefix_name + + "q_proj.lora_B.weight"] = layer.wq_.lora_b_[lora_name] + if lora_name in layer.wk_.lora_a_: + lora_weight_dict[layer_prefix_name + + "k_proj.lora_A.weigth"] = layer.wk_.lora_a_[lora_name] + if lora_name in layer.wk_.lora_b_: + lora_weight_dict[layer_prefix_name + + "k_proj.lora_B.weight"] = layer.wk_.lora_b_[lora_name] + if lora_name in layer.wv_.lora_a_: + lora_weight_dict[layer_prefix_name + + "v_proj.lora_A.weight"] = layer.wv_.lora_a_[lora_name] + if lora_name in layer.wv_.lora_b_: + lora_weight_dict[layer_prefix_name + + "v_proj.lora_B.weight"] = layer.wv_.lora_b_[lora_name] + if lora_name in layer.wo_.lora_a_: + lora_weight_dict[layer_prefix_name + + "o_proj.lora_A.weight"] = layer.wo_.lora_a_[lora_name] + if lora_name in layer.wo_.lora_b_: + lora_weight_dict[layer_prefix_name + + "o_proj.lora_B.weight"] = layer.wo_.lora_b_[lora_name] + + torch.save(lora_weight_dict, path) diff --git a/aspen/tokenizer.py b/aspen/tokenizer.py new file mode 100644 index 00000000..eb02105c --- /dev/null +++ b/aspen/tokenizer.py @@ -0,0 +1,30 @@ +from aspen.modelargs import TokenizerArgs + +from sentencepiece import SentencePieceProcessor +from typing import List + + +class Tokenizer: + def __init__(self, model_path: str): + self.token_model_ = SentencePieceProcessor(model_file=model_path) + self.n_words_ = self.token_model_.vocab_size() + self.bos_id_ = self.token_model_.bos_id() + self.eos_id_ = self.token_model_.eos_id() + self.pad_id_ = self.token_model_.pad_id() + + def encode(self, data: str, bos: bool, eos: bool) -> List[int]: + ret = self.token_model_.encode(data) + if bos: + ret = [self.bos_id_] + ret + if eos: + ret = ret + [self.eos_id_] + return ret + + def decode(self, data: List[int]) -> str: + return self.token_model_.decode(data) + + def get_args(self) -> TokenizerArgs: + return TokenizerArgs(vocab_size_=self.n_words_, + bos_id_=self.bos_id_, + eos_id_=self.eos_id_, + pad_id_=self.pad_id_) diff --git a/aspen/utils.py b/aspen/utils.py new file mode 100644 index 00000000..5e908d62 --- /dev/null +++ b/aspen/utils.py @@ -0,0 +1,8 @@ +import torch +from transformers import LlamaForCausalLM + +# convert huggingface model to pytorch model +def convert_hf_to_pth(source: str, dest: str): + src_model = LlamaForCausalLM.from_pretrained(source) + # src_model.eval() + torch.save(src_model.state_dict(), dest) diff --git a/config/lora.json b/config/lora.json new file mode 100644 index 00000000..db688301 --- /dev/null +++ b/config/lora.json @@ -0,0 +1,46 @@ +{ + "base_model": "/yezhengmao/modules/llama-7b/7B/consolidated.00.pth", + "token_model": "/yezhengmao/modules/llama-7b/tokenizer.model", + "cutoff_len": 512, + "group_by_length": false, + "expand_right": true, + "device": "cuda:1", + "lora": [ + { + "name": "lora_0", + "output": "lora_0.bin", + "batch_size": 16, + "num_epochs": 3, + "r": 8, + "alpha": 16, + "dropout": 0.05, + "target_modules": { + "q_proj": true, + "k_proj": true, + "v_proj": true, + "o_proj": true + }, + "data": "data/train_lora_a.json", + "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}\n\n", + "prompt_no_input": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n{output}\n\n" + }, + { + "name": "lora_1", + "output": "lora_1.bin", + "batch_size": 16, + "num_epochs": 3, + "r": 8, + "alpha": 16, + "dropout": 0.05, + "target_modules": { + "q_proj": true, + "k_proj": true, + "v_proj": true, + "o_proj": true + }, + "data": "data/train_lora_b.json", + "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}\n\n", + "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n{output}\n\n" + } + ] +} \ No newline at end of file diff --git a/data/demo.json b/data/demo.json deleted file mode 100644 index 8b137891..00000000 --- a/data/demo.json +++ /dev/null @@ -1 +0,0 @@ - diff --git a/mlora.py b/mlora.py new file mode 100644 index 00000000..9a33a769 --- /dev/null +++ b/mlora.py @@ -0,0 +1,94 @@ +import json +import torch +from aspen import LlamaModel, Tokenizer, DataSet +from aspen import LlamaModelArgs, MultiLoraBatchData +from aspen import load_llama_7b_weight, load_random_lora_7b_weight +from aspen import save_lora_model +import torch.optim + +with open('config/lora.json', 'r', encoding='utf8') as fp: + config = json.load(fp) + +args = LlamaModelArgs() +tokenizer = Tokenizer(config["token_model"]) +tokenizer.pad_id_ = 0 +args.max_seq_len_ = 4096 +args.device = config["device"] +args.vocab_size_ = tokenizer.n_words_ +args.pad_id_ = tokenizer.pad_id_ +args.n_heads_ = 32 +llama_model = LlamaModel(args) + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def init_lora_model(llama_model: LlamaModel): + for lora_config in config["lora"]: + load_random_lora_7b_weight( + llama_model, + lora_config["name"], + lora_config["r"], + llama_model.dim_, + lora_config["target_modules"], + llama_model.device_) + llama_model.update_lora_configure( + lora_config["name"], lora_config["r"], lora_config["alpha"], lora_config["dropout"]) + + +if __name__ == "__main__": + setup_seed(42) + + data_set = DataSet(config, tokenizer) + load_llama_7b_weight(llama_model, config["base_model"], config["device"]) + init_lora_model(llama_model) + + torch.cuda.empty_cache() + + # optim begin + optimizer = torch.optim.SGD( + llama_model.get_train_paramas(config), lr=1e-3) + # optim end + + step = 0 + # torch.autograd.set_detect_anomaly(True) + while not data_set.check_done(): + optimizer.zero_grad() + loss_fn = torch.nn.CrossEntropyLoss() + input: MultiLoraBatchData = data_set.get_batch_data() + + step += 1 + + output = llama_model.forward(input) + labels = torch.tensor(input.batch_tokens_, + dtype=torch.long).to(config["device"]) + + total_loss = None + for lora_config in input.lora_batch_data_config_: + start_idx = lora_config.batch_start_idx_ + end_idx = lora_config.batch_end_idx_ + loss_input = output[start_idx:end_idx][..., :-1, + :].contiguous().view(-1, llama_model.vocab_size_) + loss_target = labels[start_idx:end_idx][..., + 1:].contiguous().view(-1) + loss = loss_fn(loss_input, loss_target) + print( + f" adapter: {lora_config.adapter_name_} loss: {loss}") + if total_loss is None: + total_loss = loss + else: + total_loss += loss + + total_loss.backward() + optimizer.step() + + if step % 200 == 0: + for lora_config in 
config["lora"]: + save_lora_model( + llama_model, lora_config["output"] + f".chk{step}", lora_config["name"]) + + for lora_config in config["lora"]: + save_lora_model( + llama_model, lora_config["output"], lora_config["name"]) diff --git a/requirements.txt b/requirements.txt index 11fb9741..9a5baec6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ -torch==1.8.1 -transformers==4.6.0 -#... other dependencies +torch==2.1.0 +sentencepiece +xformers +einops \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..6d51f744 --- /dev/null +++ b/setup.py @@ -0,0 +1,5 @@ +from setuptools import setup, find_packages + +setup(name='aspen', + version='0.2', + packages=find_packages(exclude=["test"])) From 46d464fbe0e08a1bd8c193d9a394ce593b59cc5f Mon Sep 17 00:00:00 2001 From: Mingjie Tang Date: Tue, 29 Aug 2023 01:17:12 +0800 Subject: [PATCH 2/2] Update requirements.txt --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 9a5baec6..59db177c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch==2.1.0 +torch==2.0.1 sentencepiece xformers -einops \ No newline at end of file +einops
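
For readers skimming the patch, the central mechanism added in aspen/model.py is a linear layer that keeps one frozen base weight and applies each LoRA adapter only to the rows of the batch that belong to it; the row ranges come from the LoraBatchDataConfig entries built by DataSet.get_batch_data. The sketch below restates that idea in plain PyTorch as a minimal illustration, not as the patch's implementation: the real Linear wraps the shared weight in a bitsandbytes Linear8bitLt, and the names used here (MultiLoraLinearSketch, the (adapter_name, start, end) tuples, the toy dimensions) are invented for the example.

```python
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F


class MultiLoraLinearSketch:
    """One frozen base weight shared by several LoRA adapters (illustrative only)."""

    def __init__(self, base_weight: torch.Tensor):
        self.weight_ = base_weight                      # out_dim x in_dim, stays frozen
        self.lora_a_: Dict[str, torch.Tensor] = {}      # adapter name -> r x in_dim
        self.lora_b_: Dict[str, torch.Tensor] = {}      # adapter name -> out_dim x r
        self.scaling_: Dict[str, float] = {}
        self.dropout_: Dict[str, float] = {}

    def add_adapter(self, name: str, r: int, alpha: int, dropout: float) -> None:
        out_dim, in_dim = self.weight_.shape
        # Small random init for both matrices, mirroring load_random_lora_7b_weight.
        self.lora_a_[name] = (1e-3 * torch.randn(r, in_dim)).requires_grad_(True)
        self.lora_b_[name] = (1e-3 * torch.randn(out_dim, r)).requires_grad_(True)
        self.scaling_[name] = alpha / r
        self.dropout_[name] = dropout

    def forward(self, data: torch.Tensor,
                batch_configs: List[Tuple[str, int, int]]) -> torch.Tensor:
        # data: batch_size x seq_len x in_dim. The expensive matmul against the
        # shared base weight is computed once for the whole multi-adapter batch.
        result = data @ self.weight_.transpose(0, 1)
        # Each adapter then touches only its own slice of batch rows.
        for name, start_idx, end_idx in batch_configs:
            if name not in self.lora_a_:
                continue
            delta = F.dropout(data[start_idx:end_idx], self.dropout_[name])
            delta = delta @ self.lora_a_[name].transpose(0, 1)   # ... x r
            delta = delta @ self.lora_b_[name].transpose(0, 1)   # ... x out_dim
            result[start_idx:end_idx] += delta * self.scaling_[name]
        return result


# Toy usage: two adapters, each owning half of an 8-row batch (dimensions kept tiny
# here; the patch uses dim 4096 for LLaMA-7B).
layer = MultiLoraLinearSketch(torch.randn(64, 64))
layer.add_adapter("lora_0", r=8, alpha=16, dropout=0.05)
layer.add_adapter("lora_1", r=8, alpha=16, dropout=0.05)
out = layer.forward(torch.randn(8, 16, 64), [("lora_0", 0, 4), ("lora_1", 4, 8)])
print(out.shape)  # torch.Size([8, 16, 64])
```

In the patch itself, DataSet.get_batch_data packs one micro-batch per adapter into a single tensor and pads every sequence to a common length rounded up to a multiple of 8, so one forward pass through the shared LLaMA weights serves all adapters, and mlora.py then computes a separate cross-entropy loss per adapter slice before summing them for the backward pass.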