From 44ff3b9f7cfed1bf37f9b2cb5db03689cd8a6bcf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=99=BB=E6=B7=B3?=
Date: Tue, 29 Aug 2023 09:41:12 +0800
Subject: [PATCH 01/10] Rename mlora.py -> legacy.py

---
 mlora.py => legacy.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename mlora.py => legacy.py (100%)

diff --git a/mlora.py b/legacy.py
similarity index 100%
rename from mlora.py
rename to legacy.py

From 843e20b29d2a373f18558cd125e10036a9d3e772 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=99=BB=E6=B7=B3?=
Date: Tue, 29 Aug 2023 10:31:01 +0800
Subject: [PATCH 02/10] Add bitsandbytes to requirements.txt

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index 59db177c..880e0af4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 torch==2.0.1
 sentencepiece
+bitsandbytes
 xformers
 einops

From 806e0caf29ab3d5f377d7de553d1e2e463358831 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=99=BB=E6=B7=B3?=
Date: Tue, 29 Aug 2023 10:31:42 +0800
Subject: [PATCH 03/10] Create wrapped entry point

---
 mlora.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 59 insertions(+)
 create mode 100644 mlora.py

diff --git a/mlora.py b/mlora.py
new file mode 100644
index 00000000..7ccf0c6f
--- /dev/null
+++ b/mlora.py
@@ -0,0 +1,59 @@
+# ASPEN: Efficient Multi-LoRA Fine-Tuning with a Shared Base Model
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Copyright (C) 2023 All Rights Reserved.
+#
+# Email:
+# GitHub: https://github.com/TUDB-Labs/multi-lora-fine-tune
+# Website:
+
+import datetime
+import argparse
+import torch
+import aspen
+import os
+
+parser = argparse.ArgumentParser(description='ASPEN main program')
+parser.add_argument('--model_name_or_path', type=str, help='Path to or name of base model')
+parser.add_argument('--device', type=str, default='cuda:0', help='Specify which GPU to use, default is cuda:0')
+parser.add_argument('--log', type=bool, default=True, help='Turn logging on or off, default is true')
+
+args = parser.parse_args()
+
+
+def log(msg: str):
+    if args.log:
+        print('[%s] ASPEN: %s' % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), msg))
+
+
+if torch.cuda.is_available():
+    log('NVIDIA CUDA initialized successfully.')
+    log('Total %i GPU(s) detected.' % torch.cuda.device_count())
+else:
+    print('ASPEN requires NVIDIA CUDA support. Please check your PyTorch installation.')
+    exit(-1)
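+
+
+# Build the tokenizer and the base LLaMA model from the checkpoint given by
+# --model_name_or_path.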
+def prep_llm():
+    args = aspen.LlamaModelArgs()
+    tokenizer = aspen.Tokenizer(args.model_name_or_path + os.sep + 'tokenizer.model')
+    tokenizer.pad_id_ = 0
+    args.max_seq_len_ = 4096
+    args.device = args.device
+    args.vocab_size_ = tokenizer.n_words_
+    args.pad_id_ = tokenizer.pad_id_
+    args.n_heads_ = 32
+    model = aspen.LlamaModel(args)
+    aspen.load_llama_tf_weight(model, args.model_name_or_path, args.device)
+    return tokenizer, model
+
+
+if __name__ == "__main__":
+    tokenizer, model = prep_llm()

From 2df9e8bd80217cd17dff4aaca5ffbbe10ea9afd6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=99=BB=E6=B7=B3?=
Date: Tue, 29 Aug 2023 10:31:44 +0800
Subject: [PATCH 04/10] Update README.md

---
 README.md | 44 ++++++++++++++++++++++++++++++++++++++------
 1 file changed, 38 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 190c7c47..c8511ae8 100644
--- a/README.md
+++ b/README.md
@@ -7,16 +7,27 @@ This repository provides tools for fine-tuning large language models (LLMs) usin
 - [Updates](#updates)
 - [Overview](#overview)
+- [Installation](#installation)
 - [Getting Started](#quickstart)
 - [Contributing](#contributing)
-- [License](#license)
+- [Copyright](#copyright)
 
 ## Updates
 Support
 
 ## Overview
-
+## Installation
+```bash
+# Optional but recommended
+conda create -n aspen_env python=3.8
+# Install requirements
+pip install -r requirements.txt
+```
+After installation, you can use ASPEN directly in your code:
+```python
+import aspen
+```
 
 ## Quickstart
 The `mlora.py` code is a starting point for fine-tuning and inference on various datasets.
 ```bash
 python mlora.py --model_name_or_path <path_or_name>
 ```
 
 For models larger than 13B, we recommend adjusting the learning rate:
 ```bash
-python mlora.py –learning_rate 0.0001 --model_name_or_path <path_or_name>
+python mlora.py --learning_rate 0.0001 --model_name_or_path <path_or_name>
 ```
+
+You can check detailed usage information with the `--help` option:
+```bash
+python mlora.py --help
+```
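+
+For example, to fine-tune on the second GPU (assuming at least two CUDA devices are available), you might run:
+```bash
+python mlora.py --model_name_or_path <path_or_name> --device cuda:1
+```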
 
 ## Contributing
@@ -48,5 +64,21 @@ Submit a pull request with a detailed explanation of your changes.
 }
 ```
-## License
-This project is licensed under the Apache 2.0 License - see the LICENSE file for details
+## Copyright
+Copyright © 2023 All Rights Reserved.
+
+This project is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0).
+
+```
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+```
\ No newline at end of file

From 5c6068510af8e3a00dbc8491a6e43a79b9f71139 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=99=BB=E6=B7=B3?=
Date: Tue, 29 Aug 2023 11:41:29 +0800
Subject: [PATCH 05/10] Merge changes from @yezhengmao

---
 legacy.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/legacy.py b/legacy.py
index 9a33a769..dc1d9335 100644
--- a/legacy.py
+++ b/legacy.py
@@ -1,9 +1,10 @@
-import json
-import torch
 from aspen import LlamaModel, Tokenizer, DataSet
 from aspen import LlamaModelArgs, MultiLoraBatchData
 from aspen import load_llama_7b_weight, load_random_lora_7b_weight
 from aspen import save_lora_model
+
+import json
+import torch
 import torch.optim
 
 with open('config/lora.json', 'r', encoding='utf8') as fp:
@@ -47,19 +48,15 @@ def init_lora_model(llama_model: LlamaModel):
 
     torch.cuda.empty_cache()
 
-    # optim begin
-    optimizer = torch.optim.SGD(
-        llama_model.get_train_paramas(config), lr=1e-3)
-    # optim end
+    optimizer = torch.optim.AdamW(llama_model.get_train_paramas(config))
 
-    step = 0
-    # torch.autograd.set_detect_anomaly(True)
+    step_cnt = 0
     while not data_set.check_done():
         optimizer.zero_grad()
         loss_fn = torch.nn.CrossEntropyLoss()
         input: MultiLoraBatchData = data_set.get_batch_data()
-        step += 1
+        step_cnt += 1
 
         output = llama_model.forward(input)
         labels = torch.tensor(input.batch_tokens_,
@@ -84,11 +81,11 @@ def init_lora_model(llama_model: LlamaModel):
         total_loss.backward()
         optimizer.step()
 
-        if step % 200 == 0:
+        if step_cnt % config["save_step"] == 0:
             for lora_config in config["lora"]:
                 save_lora_model(
-                    llama_model, lora_config["output"] + f".chk{step}", lora_config["name"])
+                    llama_model, lora_config["output"] + f".bin{step_cnt}", lora_config["name"])
 
     for lora_config in config["lora"]:
         save_lora_model(
-            llama_model, lora_config["output"], lora_config["name"])
+            llama_model, lora_config["output"], lora_config["name"])
\ No newline at end of file

From 69c7a2bfb2dcc84fd2f59427ebdbc4513d69cd2e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=99=BB=E6=B7=B3?=
Date: Tue, 29 Aug 2023 16:43:42 +0800
Subject: [PATCH 06/10] Add 8-bit bypass

---
 aspen/model.py       | 15 +++++++++------
 aspen/modelloader.py | 18 +++++++++---------
 mlora.py             | 20 +++++++++++++------
 3 files changed, 31 insertions(+), 22 deletions(-)

diff --git a/aspen/model.py b/aspen/model.py
index 051c83de..00cb6523 100644
--- a/aspen/model.py
+++ b/aspen/model.py
@@ -7,7 +7,6 @@
 import xformers.ops
 import xformers.ops.fmha.attn_bias
 from typing import List, Dict, Set, Tuple
-from bitsandbytes.nn import Linear8bitLt, Int8Params
 
 
 def precompute_rope_angle(dim: int, seq_len: int, device: str, theta: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -97,12 +96,16 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
 
 class Linear():
-    def __init__(self, weight: torch.Tensor):
+    def __init__(self, weight: torch.Tensor, load_in_8bit: bool = True):
         row, col = weight.shape
-        self.weight_ = Linear8bitLt(
-            input_features=col, output_features=row, bias=False, has_fp16_weights=False)
-        self.weight_.weight = Int8Params(
-            weight.data, requires_grad=False).cuda(weight.device)
+        if load_in_8bit:
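+            # Imported lazily so bitsandbytes is only required when 8-bit
+            # loading is actually requested.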
+            from bitsandbytes.nn import Linear8bitLt, Int8Params
+            self.weight_ = Linear8bitLt(
+                input_features=col, output_features=row, bias=False, has_fp16_weights=False)
+            self.weight_.weight = Int8Params(
+                weight.data, requires_grad=False).cuda(weight.device)
+        else:
+            # Keep the original full-precision weight when not quantizing.
+            self.weight_ = torch.nn.Linear(in_features=col, out_features=row, bias=False)
+            self.weight_.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.use_adapter_: bool = False
         # adapter list
         self.adapter_names_: Set[str] = set()

diff --git a/aspen/modelloader.py b/aspen/modelloader.py
index 24769c7a..ee53ed2f 100644
--- a/aspen/modelloader.py
+++ b/aspen/modelloader.py
@@ -48,9 +48,9 @@ def load_llama_7b_weight(model: LlamaModel, llama_model_path: str, device: str):
             print(f"Skipping layer {layer_name}.", file=sys.stderr)
 
 
-def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str):
+def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str, load_in_8bit: bool = False):
     weight = LlamaForCausalLM.from_pretrained(
-        llama_model_path, device_map=dev).state_dict()
+        llama_model_path, device_map=dev, load_in_8bit=load_in_8bit).state_dict()
 
     for layer_name in weight:
         w: torch.Tensor = weight[layer_name]
@@ -60,19 +60,19 @@ def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str):
             layer_name = layer_name[len("model.layers."):]
             layer_id = int(layer_name[:layer_name.find(".")])
             if "self_attn.q_proj" in layer_name:
-                model.layers_[layer_id].wq_ = Linear(w)
+                model.layers_[layer_id].wq_ = Linear(w, load_in_8bit)
             elif "self_attn.k_proj" in layer_name:
-                model.layers_[layer_id].wk_ = Linear(w)
+                model.layers_[layer_id].wk_ = Linear(w, load_in_8bit)
             elif "self_attn.v_proj" in layer_name:
-                model.layers_[layer_id].wv_ = Linear(w)
+                model.layers_[layer_id].wv_ = Linear(w, load_in_8bit)
             elif "self_attn.o_proj" in layer_name:
-                model.layers_[layer_id].wo_ = Linear(w)
+                model.layers_[layer_id].wo_ = Linear(w, load_in_8bit)
             elif "mlp.gate_proj" in layer_name:
-                model.layers_[layer_id].w1_ = Linear(w)
+                model.layers_[layer_id].w1_ = Linear(w, load_in_8bit)
             elif "mlp.down_proj" in layer_name:
-                model.layers_[layer_id].w2_ = Linear(w)
+                model.layers_[layer_id].w2_ = Linear(w, load_in_8bit)
             elif "mlp.up_proj" in layer_name:
-                model.layers_[layer_id].w3_ = Linear(w)
+                model.layers_[layer_id].w3_ = Linear(w, load_in_8bit)
             elif "input_layernorm" in layer_name:
                 model.layers_[layer_id].attention_norm_ = RMSNorm(
                     w, model.norm_eps_)

diff --git a/mlora.py b/mlora.py
index b877eee1..24a6f1dd 100644
--- a/mlora.py
+++ b/mlora.py
@@ -43,16 +43,22 @@ def log(msg: str):
     exit(-1)
 
 
+if args.model_name_or_path == None:
+    print('error: Argument --model_name_or_path is required.')
+    parser.print_help()
+    exit(-1)
+
+
 # Build the tokenizer and the base LLaMA model from the checkpoint given by
 # --model_name_or_path.
 def prep_llm():
-    args = aspen.LlamaModelArgs()
+    llama_args = aspen.LlamaModelArgs()
     tokenizer = aspen.Tokenizer(args.model_name_or_path + os.sep + 'tokenizer.model')
     tokenizer.pad_id_ = 0
-    args.max_seq_len_ = 4096
-    args.device = args.device
-    args.vocab_size_ = tokenizer.n_words_
-    args.pad_id_ = tokenizer.pad_id_
-    args.n_heads_ = 32
-    model = aspen.LlamaModel(args)
+    llama_args.max_seq_len_ = 4096
+    llama_args.device = args.device
+    llama_args.vocab_size_ = tokenizer.n_words_
+    llama_args.pad_id_ = tokenizer.pad_id_
+    llama_args.n_heads_ = 32
+    model = aspen.LlamaModel(llama_args)
     aspen.load_llama_tf_weight(model, args.model_name_or_path, args.device)
     return tokenizer, model

From d897aafb65f3c00811e960eb50eb3e5ecf389917 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=99=BB=E6=B7=B3?=
Date: Wed, 30 Aug 2023 09:14:32 +0800
Subject: [PATCH 07/10] Update mlora.py

---
 mlora.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlora.py b/mlora.py
index 24a6f1dd..1fdd3f5d 100644
--- a/mlora.py
+++ b/mlora.py
@@ -24,6 +24,7 @@
 parser = argparse.ArgumentParser(description='ASPEN main program')
 parser.add_argument('--model_name_or_path', type=str, help='Path to or name of base model')
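+# NOTE: argparse's type=bool is not a real switch: any non-empty value,
+# including 'False', parses as True. Omit a flag to keep its default.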
+parser.add_argument('--load_in_8bit', type=bool, default=False, help='Load the model in 8-bit mode')
 parser.add_argument('--device', type=str, default='cuda:0', help='Specify which GPU to use, default is cuda:0')
 parser.add_argument('--log', type=bool, default=True, help='Turn logging on or off, default is true')
@@ -59,7 +60,7 @@ def prep_llm():
     llama_args.pad_id_ = tokenizer.pad_id_
     llama_args.n_heads_ = 32
     model = aspen.LlamaModel(llama_args)
-    aspen.load_llama_tf_weight(model, args.model_name_or_path, args.device)
+    aspen.load_llama_tf_weight(model, args.model_name_or_path, args.device, args.load_in_8bit)
     return tokenizer, model

From 0c712e0c568112d00bf48f76de0a290bb142c444 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=99=BB=E6=B7=B3?=
Date: Wed, 30 Aug 2023 09:14:36 +0800
Subject: [PATCH 08/10] Update README.md

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index c8511ae8..9aebc00c 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,7 @@
 ```bash
 # Optional but recommended
 conda create -n aspen_env python=3.8
+conda activate aspen_env
 # Install requirements
 pip install -r requirements.txt
 ```

From 05b5cf21dbf413c85e2d4363c0cbe8899ee3d8d9 Mon Sep 17 00:00:00 2001
From: Michael Lee
Date: Wed, 30 Aug 2023 09:45:37 +0800
Subject: [PATCH 09/10] Fix unreasonable memory allocation

---
 aspen/model.py       | 10 ++++++----
 aspen/modelloader.py | 27 +++++++++++++--------------
 requirements.txt     |  1 +
 3 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/aspen/model.py b/aspen/model.py
index 00cb6523..3f5a75f5 100644
--- a/aspen/model.py
+++ b/aspen/model.py
@@ -96,16 +96,18 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
 
 class Linear():
-    def __init__(self, weight: torch.Tensor, load_in_8bit: bool = True):
+    def __init__(self, weight: torch.Tensor, load_in_8bit: bool = True, device: str = None):
+        if device == None:
+            device = weight.device
         row, col = weight.shape
         if load_in_8bit:
             # Imported lazily so bitsandbytes is only required when 8-bit
             # loading is actually requested.
             from bitsandbytes.nn import Linear8bitLt, Int8Params
             self.weight_ = Linear8bitLt(
-                input_features=col, output_features=row, bias=False, has_fp16_weights=False)
+                input_features=col, output_features=row, bias=False, has_fp16_weights=False, device=device)
             self.weight_.weight = Int8Params(
-                weight.data, requires_grad=False).cuda(weight.device)
+                weight.data, requires_grad=False).cuda(device)
         else:
             # Keep the original full-precision weight when not quantizing.
-            self.weight_ = torch.nn.Linear(in_features=col, out_features=row, bias=False)
-            self.weight_.weight = torch.nn.Parameter(weight, requires_grad=False)
+            self.weight_ = torch.nn.Linear(in_features=col, out_features=row, bias=False, device=device)
+            self.weight_.weight = torch.nn.Parameter(weight.to(device), requires_grad=False)
         self.use_adapter_: bool = False

diff --git a/aspen/modelloader.py b/aspen/modelloader.py
index 0d2e5ae4..85d46fa4 100644
--- a/aspen/modelloader.py
+++ b/aspen/modelloader.py
@@ -52,8 +52,7 @@ def load_llama_7b_weight(model: LlamaModel, llama_model_path: str, device: str):
 
 def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str, load_in_8bit: bool = False):
-    weight = LlamaForCausalLM.from_pretrained(
-        llama_model_path, device_map=dev, load_in_8bit=load_in_8bit).state_dict()
+    weight = LlamaForCausalLM.from_pretrained(llama_model_path).state_dict(keep_vars=True)
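+    # The checkpoint is materialized on the CPU here; each tensor is moved
+    # to `dev` below as it is wrapped, so only one full copy lives on the GPU.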
 
     for layer_name in weight:
         w: torch.Tensor = weight[layer_name]
@@ -63,34 +62,34 @@ def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str, loa
             layer_name = layer_name[len("model.layers."):]
             layer_id = int(layer_name[:layer_name.find(".")])
             if "self_attn.q_proj" in layer_name:
-                model.layers_[layer_id].wq_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].wq_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.k_proj" in layer_name:
-                model.layers_[layer_id].wk_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].wk_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.v_proj" in layer_name:
-                model.layers_[layer_id].wv_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].wv_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.o_proj" in layer_name:
-                model.layers_[layer_id].wo_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].wo_ = Linear(w, load_in_8bit, dev)
             elif "mlp.gate_proj" in layer_name:
-                model.layers_[layer_id].w1_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].w1_ = Linear(w, load_in_8bit, dev)
             elif "mlp.down_proj" in layer_name:
-                model.layers_[layer_id].w2_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].w2_ = Linear(w, load_in_8bit, dev)
             elif "mlp.up_proj" in layer_name:
-                model.layers_[layer_id].w3_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].w3_ = Linear(w, load_in_8bit, dev)
             elif "input_layernorm" in layer_name:
                 model.layers_[layer_id].attention_norm_ = RMSNorm(
-                    w, model.norm_eps_)
+                    w.to(device=dev), model.norm_eps_)
             elif "post_attention_layernorm" in layer_name:
                 model.layers_[layer_id].ffn_norm_ = RMSNorm(
-                    w, model.norm_eps_)
+                    w.to(device=dev), model.norm_eps_)
             else:
                 print(
                     f"Skipping layer model.layers.{layer_name}.", file=sys.stderr)
         elif "embed_tokens" in layer_name:
-            model.token_embedding_ = w
+            model.token_embedding_ = w.to(device=dev)
         elif "norm.weight" in layer_name:
-            model.norm_ = RMSNorm(w, model.norm_eps_)
+            model.norm_ = RMSNorm(w.to(device=dev), model.norm_eps_)
         elif "lm_head.weight" in layer_name:
-            model.output_ = w.to(torch.float32)
+            model.output_ = w.to(dtype=torch.float32, device=dev)
         else:
             print(f"Skipping layer {layer_name}.", file=sys.stderr)

diff --git a/requirements.txt b/requirements.txt
index 625740b2..bee2672e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ transformers
 bitsandbytes
 sentencepiece
 scipy
+accelerate
\ No newline at end of file

From 55092b8caca14d1a07e535a5515343d365f9b69f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=99=BB=E6=B7=B3?=
Date: Wed, 30 Aug 2023 09:57:26 +0800
Subject: [PATCH 10/10] Fix problems reported by lint

---
 aspen/model.py | 2 +-
 mlora.py       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/aspen/model.py b/aspen/model.py
index 3f5a75f5..6c3dff03 100644
--- a/aspen/model.py
+++ b/aspen/model.py
@@ -97,7 +97,7 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
 class Linear():
     def __init__(self, weight: torch.Tensor, load_in_8bit: bool = True, device: str = None):
-        if device == None:
+        if device is None:
             device = weight.device
         row, col = weight.shape
         if load_in_8bit:

diff --git a/mlora.py b/mlora.py
index 1fdd3f5d..a6d71146 100644
--- a/mlora.py
+++ b/mlora.py
@@ -44,7 +44,7 @@ def log(msg: str):
     exit(-1)
 
 
-if args.model_name_or_path == None:
+if args.model_name_or_path is None:
     print('error: Argument --model_name_or_path is required.')
     parser.print_help()
     exit(-1)