From f5972cac5ad30c995090b305c4f5f13373ec55f7 Mon Sep 17 00:00:00 2001
From: Michael Lee
Date: Tue, 29 Aug 2023 11:46:06 +0800
Subject: [PATCH] Some minor improvement (#10)

* Rename mlora.py -> legacy.py
* Add bitsandbytes to requirements.txt
* Create wrapped entrance code
* Update README.md
* Merge changes from @yezhengmao
---
 README.md        |  44 +++++++++++++--
 legacy.py        |  91 ++++++++++++++++++++++++++++++
 mlora.py         | 144 ++++++++++++++++++-----------------------------
 requirements.txt |   1 +
 4 files changed, 186 insertions(+), 94 deletions(-)
 create mode 100644 legacy.py

diff --git a/README.md b/README.md
index 190c7c47..c8511ae8 100644
--- a/README.md
+++ b/README.md
@@ -7,16 +7,27 @@ This repository provides tools for fine-tuning large language models (LLMs) usin
 - [Updates](#updates)
 - [Overview](#overview)
+- [Installation](#Installation)
 - [Getting Started](#Quickstart)
-- [Contributing](#contributing)
-- [License](#license)
+- [Contributing](#Contributing)
+- [Copyright](#Copyright)
 
 ## Updates
 Support
 
 ## Overview
-
+## Installation
+```bash
+# Optional but recommended
+conda create -n aspen_env python=3.8
+# Install requirements
+pip install -r requirements.txt
+```
+After installation, you can use ASPEN directly in your code:
+```python
+import aspen
+```
 
 ## Quickstart
 
 The `mlora.py` code is a starting point for finetuning and inference on various datasets.
@@ -27,7 +38,12 @@ python mlora.py --model_name_or_path
 
 For models larger than 13B, we recommend adjusting the learning rate:
 ```bash
-python mlora.py –learning_rate 0.0001 --model_name_or_path
+python mlora.py --learning_rate 0.0001 --model_name_or_path
+```
+
+You can check detailed usage information with the `--help` option:
+```bash
+python mlora.py --help
 ```
 
 ## Contributing
 Submit a pull request with a detailed explanation of your changes.
 }
 ```
 
-## License
-This project is licensed under the Apache 2.0 License - see the LICENSE file for details
+## Copyright
+Copyright © 2023 All Rights Reserved.
+
+This project is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0).
+
+```
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+```
\ No newline at end of file
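The README above stops at `import aspen`. The sketch below fills in what a minimal programmatic setup might look like, using only names that appear later in this patch (`LlamaModelArgs`, `Tokenizer`, `LlamaModel`, `load_llama_tf_weight`); the checkpoint directory is a hypothetical placeholder, and the code mirrors `prep_llm()` from mlora.py rather than documenting an official API.

```python
# Illustrative sketch only: mirrors prep_llm() in mlora.py below.
# '/path/to/llama-7b' is a hypothetical placeholder, not part of this patch.
import os
import aspen

model_dir = '/path/to/llama-7b'
device = 'cuda:0'

model_args = aspen.LlamaModelArgs()
tokenizer = aspen.Tokenizer(model_dir + os.sep + 'tokenizer.model')
tokenizer.pad_id_ = 0
model_args.max_seq_len_ = 4096
model_args.device = device
model_args.vocab_size_ = tokenizer.n_words_
model_args.pad_id_ = tokenizer.pad_id_
model_args.n_heads_ = 32

model = aspen.LlamaModel(model_args)
aspen.load_llama_tf_weight(model, model_dir, device)
```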
diff --git a/legacy.py b/legacy.py
new file mode 100644
index 00000000..dc1d9335
--- /dev/null
+++ b/legacy.py
@@ -0,0 +1,91 @@
+from aspen import LlamaModel, Tokenizer, DataSet
+from aspen import LlamaModelArgs, MultiLoraBatchData
+from aspen import load_llama_7b_weight, load_random_lora_7b_weight
+from aspen import save_lora_model
+
+import json
+import torch
+import torch.optim
+
+with open('config/lora.json', 'r', encoding='utf8') as fp:
+    config = json.load(fp)
+
+args = LlamaModelArgs()
+tokenizer = Tokenizer(config["token_model"])
+tokenizer.pad_id_ = 0
+args.max_seq_len_ = 4096
+args.device = config["device"]
+args.vocab_size_ = tokenizer.n_words_
+args.pad_id_ = tokenizer.pad_id_
+args.n_heads_ = 32
+llama_model = LlamaModel(args)
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def init_lora_model(llama_model: LlamaModel):
+    for lora_config in config["lora"]:
+        load_random_lora_7b_weight(
+            llama_model,
+            lora_config["name"],
+            lora_config["r"],
+            llama_model.dim_,
+            lora_config["target_modules"],
+            llama_model.device_)
+        llama_model.update_lora_configure(
+            lora_config["name"], lora_config["r"], lora_config["alpha"], lora_config["dropout"])
+
+
+if __name__ == "__main__":
+    setup_seed(42)
+
+    data_set = DataSet(config, tokenizer)
+    load_llama_7b_weight(llama_model, config["base_model"], config["device"])
+    init_lora_model(llama_model)
+
+    torch.cuda.empty_cache()
+
+    optimizer = torch.optim.AdamW(llama_model.get_train_paramas(config))
+
+    step_cnt = 0
+    while not data_set.check_done():
+        optimizer.zero_grad()
+        loss_fn = torch.nn.CrossEntropyLoss()
+        input: MultiLoraBatchData = data_set.get_batch_data()
+
+        step_cnt += 1
+
+        output = llama_model.forward(input)
+        labels = torch.tensor(input.batch_tokens_,
+                              dtype=torch.long).to(config["device"])
+
+        total_loss = None
+        for lora_config in input.lora_batch_data_config_:
+            start_idx = lora_config.batch_start_idx_
+            end_idx = lora_config.batch_end_idx_
+            loss_input = output[start_idx:end_idx][..., :-1,
+                                                   :].contiguous().view(-1, llama_model.vocab_size_)
+            loss_target = labels[start_idx:end_idx][...,
+                                                    1:].contiguous().view(-1)
+            loss = loss_fn(loss_input, loss_target)
+            print(
+                f"    adapter: {lora_config.adapter_name_} loss: {loss}")
+            if total_loss is None:
+                total_loss = loss
+            else:
+                total_loss += loss
+
+        total_loss.backward()
+        optimizer.step()
+
+        if step_cnt % config["save_step"] == 0:
+            for lora_config in config["lora"]:
+                save_lora_model(
+                    llama_model, lora_config["output"] + f".bin{step_cnt}", lora_config["name"])
+
+    for lora_config in config["lora"]:
+        save_lora_model(
+            llama_model, lora_config["output"], lora_config["name"])
\ No newline at end of file
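The training loop preserved in legacy.py computes one shifted cross-entropy loss per LoRA adapter over that adapter's slice of the batch and sums the results. Below is a self-contained sketch of just that piece, with dummy tensors and made-up adapter names standing in for the real model output and batch configuration.

```python
# Minimal sketch of the per-adapter loss in legacy.py, using dummy data.
# Adapter names and batch slices are invented for illustration only.
import torch

vocab_size, seq_len = 32, 16
loss_fn = torch.nn.CrossEntropyLoss()

logits = torch.randn(4, seq_len, vocab_size, requires_grad=True)  # stands in for the model output
labels = torch.randint(0, vocab_size, (4, seq_len))               # stands in for batch_tokens_

# Each adapter owns a contiguous [start, end) slice of the batch.
adapter_slices = {"lora_a": (0, 2), "lora_b": (2, 4)}

total_loss = None
for name, (start_idx, end_idx) in adapter_slices.items():
    # Next-token objective: drop the last logit and the first label, then flatten.
    loss_input = logits[start_idx:end_idx][..., :-1, :].contiguous().view(-1, vocab_size)
    loss_target = labels[start_idx:end_idx][..., 1:].contiguous().view(-1)
    loss = loss_fn(loss_input, loss_target)
    print(f"    adapter: {name} loss: {loss}")
    total_loss = loss if total_loss is None else total_loss + loss

total_loss.backward()  # optimizer.step() would follow in the real loop
```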
diff --git a/mlora.py b/mlora.py
index 486ef3f0..7ccf0c6f 100644
--- a/mlora.py
+++ b/mlora.py
@@ -1,91 +1,59 @@
-from aspen import LlamaModel, Tokenizer, DataSet
-from aspen import LlamaModelArgs, MultiLoraBatchData
-from aspen import load_llama_7b_weight, load_random_lora_7b_weight
-from aspen import save_lora_model
-
-import json
+# ASPEN: Efficient Multi-LoRA Fine Tuning with Shared Base Model
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Copyright (C) 2023 All Rights Reserved.
+#
+# Email:
+# Github: https://github.com/TUDB-Labs/multi-lora-fine-tune
+# Website:
+
+import datetime
+import argparse
 import torch
-import torch.optim
-
-with open('config/lora.json', 'r', encoding='utf8') as fp:
-    config = json.load(fp)
-
-args = LlamaModelArgs()
-tokenizer = Tokenizer(config["token_model"])
-tokenizer.pad_id_ = 0
-args.max_seq_len_ = 4096
-args.device = config["device"]
-args.vocab_size_ = tokenizer.n_words_
-args.pad_id_ = tokenizer.pad_id_
-args.n_heads_ = 32
-llama_model = LlamaModel(args)
-
-
-def setup_seed(seed):
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-
-
-def init_lora_model(llama_model: LlamaModel):
-    for lora_config in config["lora"]:
-        load_random_lora_7b_weight(
-            llama_model,
-            lora_config["name"],
-            lora_config["r"],
-            llama_model.dim_,
-            lora_config["target_modules"],
-            llama_model.device_)
-        llama_model.update_lora_configure(
-            lora_config["name"], lora_config["r"], lora_config["alpha"], lora_config["dropout"])
-
+import aspen
+import os
+
+parser = argparse.ArgumentParser(description='ASPEN main program')
+parser.add_argument('--model_name_or_path', type=str, help='Path to or name of the base model')
+parser.add_argument('--device', type=str, default='cuda:0', help='Specify which GPU to use, default is cuda:0')
+parser.add_argument('--log', type=bool, default=True, help='Turn logging on or off, default is true')
+
+args = parser.parse_args()
+
+
+def log(msg: str):
+    if args.log:
+        print('[%s] ASPEN: %s' % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), msg))
+
+
+if torch.cuda.is_available():
+    log('NVIDIA CUDA initialized successfully.')
+    log('Total %i GPU(s) detected.' % torch.cuda.device_count())
+else:
+    print('ASPEN requires an NVIDIA GPU with CUDA support. Please check your PyTorch installation.')
+    exit(-1)
+
+
+def prep_llm():
+    # Use a separate name for the model arguments so the CLI `args` above is not shadowed.
+    model_args = aspen.LlamaModelArgs()
+    tokenizer = aspen.Tokenizer(args.model_name_or_path + os.sep + 'tokenizer.model')
+    tokenizer.pad_id_ = 0
+    model_args.max_seq_len_ = 4096
+    model_args.device = args.device
+    model_args.vocab_size_ = tokenizer.n_words_
+    model_args.pad_id_ = tokenizer.pad_id_
+    model_args.n_heads_ = 32
+    model = aspen.LlamaModel(model_args)
+    aspen.load_llama_tf_weight(model, args.model_name_or_path, args.device)
+    return tokenizer, model
+
 
 if __name__ == "__main__":
-    setup_seed(42)
-
-    data_set = DataSet(config, tokenizer)
-    load_llama_7b_weight(llama_model, config["base_model"], config["device"])
-    init_lora_model(llama_model)
-
-    torch.cuda.empty_cache()
-
-    optimizer = torch.optim.AdamW(llama_model.get_train_paramas(config))
-
-    step_cnt = 0
-    while not data_set.check_done():
-        optimizer.zero_grad()
-        loss_fn = torch.nn.CrossEntropyLoss()
-        input: MultiLoraBatchData = data_set.get_batch_data()
-
-        step_cnt += 1
-
-        output = llama_model.forward(input)
-        labels = torch.tensor(input.batch_tokens_,
-                              dtype=torch.long).to(config["device"])
-
-        total_loss = None
-        for lora_config in input.lora_batch_data_config_:
-            start_idx = lora_config.batch_start_idx_
-            end_idx = lora_config.batch_end_idx_
-            loss_input = output[start_idx:end_idx][..., :-1,
-                                                   :].contiguous().view(-1, llama_model.vocab_size_)
-            loss_target = labels[start_idx:end_idx][...,
-                                                    1:].contiguous().view(-1)
-            loss = loss_fn(loss_input, loss_target)
-            print(
-                f"    adapter: {lora_config.adapter_name_} loss: {loss}")
-            if total_loss is None:
-                total_loss = loss
-            else:
-                total_loss += loss
-
-        total_loss.backward()
-        optimizer.step()
-
-        if step_cnt % config["save_step"] == 0:
-            for lora_config in config["lora"]:
-                save_lora_model(
-                    llama_model, lora_config["output"] + f".bin{step_cnt}", lora_config["name"])
-
-    for lora_config in config["lora"]:
-        save_lora_model(
-            llama_model, lora_config["output"], lora_config["name"])
+    tokenizer, model = prep_llm()
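The README documents a `--learning_rate` option, but the argument parser in this first version of mlora.py only defines `--model_name_or_path`, `--device`, and `--log`. A hedged sketch of how such a flag could be wired into the AdamW optimizer used by the training loop follows; the flag, its default value, and the stand-in parameter are assumptions for illustration, not part of this patch.

```python
# Hypothetical sketch: a --learning_rate flag feeding torch.optim.AdamW.
# The default value and the dummy parameter are illustrative assumptions.
import argparse
import torch

parser = argparse.ArgumentParser(description='ASPEN main program')
parser.add_argument('--learning_rate', type=float, default=1e-4,
                    help='Learning rate for the optimizer (assumed default)')
args, _ = parser.parse_known_args()

# In ASPEN the trainable parameters would come from the LoRA adapters;
# a single dummy parameter keeps this sketch self-contained.
params = [torch.nn.Parameter(torch.zeros(8, 8))]
optimizer = torch.optim.AdamW(params, lr=args.learning_rate)
print('optimizer lr:', optimizer.param_groups[0]['lr'])
```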
diff --git a/requirements.txt b/requirements.txt
index 59db177c..880e0af4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 torch==2.0.1
 sentencepiece
+bitsandbytes
 xformers
 einops
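requirements.txt now pins bitsandbytes, although nothing in this patch imports it yet. One common reason to pull the library in is 8-bit optimizer states; the sketch below shows that pattern as a drop-in for the `torch.optim.AdamW` call in legacy.py. This is illustration only (it needs a CUDA GPU), not a statement of how ASPEN will actually use bitsandbytes.

```python
# Hedged sketch: bitsandbytes' 8-bit AdamW in place of torch.optim.AdamW.
# How ASPEN will actually use bitsandbytes is not shown in this patch.
import torch
import bitsandbytes as bnb

# A dummy trainable parameter on the GPU (bitsandbytes optimizers require CUDA).
param = torch.nn.Parameter(torch.randn(16, 16, device='cuda'))
optimizer = bnb.optim.AdamW8bit([param], lr=1e-4)

loss = (param ** 2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```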