From d2a73ccc85efbe2c9075e49ab4d96e23fccb0008 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=99=BB=E6=B7=B3?=
Date: Wed, 30 Aug 2023 13:27:14 +0800
Subject: [PATCH 1/3] update mlora.py

---
 config/finetune.json |  48 +++++++++++++++++++
 mlora.py             | 111 ++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 148 insertions(+), 11 deletions(-)
 create mode 100644 config/finetune.json

diff --git a/config/finetune.json b/config/finetune.json
new file mode 100644
index 00000000..d347c5e0
--- /dev/null
+++ b/config/finetune.json
@@ -0,0 +1,48 @@
+{
+  "cutoff_len": 256,
+  "group_by_length": false,
+  "expand_right": true,
+  "save_step": 200,
+  "lora": [
+    {
+      "name": "lora_0",
+      "output": "lora_0",
+      "batch_size": 16,
+      "num_epochs": 3,
+      "r": 8,
+      "alpha": 16,
+      "dropout": 0.05,
+      "target_modules": {
+        "q_proj": true,
+        "k_proj": true,
+        "v_proj": true,
+        "o_proj": true,
+        "w1_proj": false,
+        "w2_proj": false,
+        "w3_proj": false
+      },
+      "data": "data/data_demo.json",
+      "prompt": "template/template_demo.json"
+    },
+    {
+      "name": "lora_1",
+      "output": "lora_1",
+      "batch_size": 16,
+      "num_epochs": 3,
+      "r": 8,
+      "alpha": 16,
+      "dropout": 0.05,
+      "target_modules": {
+        "q_proj": true,
+        "k_proj": true,
+        "v_proj": true,
+        "o_proj": true,
+        "w1_proj": false,
+        "w2_proj": false,
+        "w3_proj": false
+      },
+      "data": "data/data_demo.json",
+      "prompt": "template/template_demo.json"
+    }
+  ]
+}
\ No newline at end of file
diff --git a/mlora.py b/mlora.py
index a6d71146..7fd20166 100644
--- a/mlora.py
+++ b/mlora.py
@@ -20,20 +20,32 @@
 import argparse
 import torch
 import aspen
+import json
 import os
 
+# Command Line Arguments
+
 parser = argparse.ArgumentParser(description='ASPEN main program')
-parser.add_argument('--model_name_or_path', type=str, help='Path to or name of base model')
-parser.add_argument('--load_in_8bit', type=bool, default=False, help='Load model in 8bit mode')
-parser.add_argument('--device', type=str, default='cuda:0', help='Specify which GPU to be used, default is cuda:0')
-parser.add_argument('--log', type=bool, default=True, help='Turn on or off log, default is true')
+parser.add_argument('--base_model', type=str,
+                    help='Path to or name of base model')
+parser.add_argument('--load_8bit', type=lambda x: x.lower() == 'true',
+                    default=False, help='Load model in 8bit mode')
+parser.add_argument('--device', type=str, default='cuda:0',
+                    help='Specify which GPU to use, default is cuda:0')
+parser.add_argument('--config', type=str,
+                    help='Path to finetune configuration')
+parser.add_argument('--seed', type=int, default=42,
+                    help='Random seed in integer, default is 42')
+parser.add_argument('--log', type=lambda x: x.lower() == 'true',
+                    default=True, help='Turn logging on or off, default is true')
 
 args = parser.parse_args()
 
 
 def log(msg: str):
     if args.log:
-        print('[%s] ASPEN: %s' % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), msg))
+        print('[%s] ASPEN: %s' %
+              (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), msg))
 
 
 if torch.cuda.is_available():
@@ -44,15 +56,29 @@ def log(msg: str):
     exit(-1)
 
 
-if args.model_name_or_path is None:
-    print('error: Argument --model_name_or_path are required.')
+if args.base_model is None:
+    print('error: Argument --base_model is required.')
+    parser.print_help()
+    exit(-1)
+
+
+if args.config is None:
+    print('error: Argument --config is required.')
     parser.print_help()
     exit(-1)
 
 
-def prep_llm():
+# Functions
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def load_base_model():
     llama_args = aspen.LlamaModelArgs()
-    tokenizer = aspen.Tokenizer(args.model_name_or_path + os.sep + 'tokenizer.model')
+    tokenizer = aspen.Tokenizer(args.base_model + os.sep + 'tokenizer.model')
     tokenizer.pad_id_ = 0
     llama_args.max_seq_len_ = 4096
     llama_args.device = args.device
@@ -60,9 +86,72 @@ def prep_llm():
     llama_args.pad_id_ = tokenizer.pad_id_
     llama_args.n_heads_ = 32
     model = aspen.LlamaModel(llama_args)
-    aspen.load_llama_tf_weight(model, args.model_name_or_path, args.device, args.load_in_8bit)
+    aspen.load_llama_tf_weight(
+        model, args.base_model, args.device, args.load_8bit)
     return tokenizer, model
 
 
+def init_lora_model(config: dict, llama_model: aspen.LlamaModel):
+    for lora_config in config["lora"]:
+        aspen.load_random_lora_7b_weight(
+            llama_model,
+            lora_config["name"],
+            lora_config["r"],
+            llama_model.dim_,
+            lora_config["target_modules"],
+            llama_model.device_)
+        llama_model.update_lora_configure(
+            lora_config["name"], lora_config["r"], lora_config["alpha"], lora_config["dropout"])
+
+
+def train(config: dict, llama_model: aspen.LlamaModel):
+    torch.cuda.empty_cache()
+    optimizer = torch.optim.AdamW(llama_model.get_train_paramas(config))
+    step_cnt = 0
+    while not data_set.check_done():
+        optimizer.zero_grad()
+        loss_fn = torch.nn.CrossEntropyLoss()
+        input: aspen.MultiLoraBatchData = data_set.get_batch_data()
+
+        step_cnt += 1
+
+        output = llama_model.forward(input)
+        labels = torch.tensor(input.batch_tokens_,
+                              dtype=torch.long).to(args.device)
+
+        total_loss = None
+        for lora_config in input.lora_batch_data_config_:
+            start_idx = lora_config.batch_start_idx_
+            end_idx = lora_config.batch_end_idx_
+            loss_input = output[start_idx:end_idx][..., :-1,
+                                                   :].contiguous().view(-1, llama_model.vocab_size_)
+            loss_target = labels[start_idx:end_idx][...,
+                                                    1:].contiguous().view(-1)
+            loss = loss_fn(loss_input, loss_target)
+            print(
+                f"    adapter: {lora_config.adapter_name_} loss: {loss}")
+            if total_loss is None:
+                total_loss = loss
+            else:
+                total_loss += loss
+
+        total_loss.backward()
+        optimizer.step()
+
+        if step_cnt % config["save_step"] == 0:
+            aspen.save_lora_model(llama_model, config, f"{step_cnt}")
+
+    aspen.save_lora_model(llama_model, config)
+
+
+# Main Function
+
+
 if __name__ == "__main__":
-    tokenizer, model = prep_llm()
+    setup_seed(args.seed)
+    with open(args.config, 'r', encoding='utf8') as fp:
+        config = json.load(fp)
+    tokenizer, model = load_base_model()
+    data_set = aspen.DataSet(config, tokenizer)
+    init_lora_model(config, model)
+    train(config, model)

From c8e4fd08f8c0abbc230e997e9c21b453d0c68090 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=99=BB=E6=B7=B3?=
Date: Wed, 30 Aug 2023 13:33:23 +0800
Subject: [PATCH 2/3] Update README.md

---
 README.md | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 9aebc00c..56776986 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ Support
 ## Installation
 ```bash
 # Optional but recommended
-conda create -n aspen_env python=3.6
+conda create -n aspen_env python=3.8
 conda activate aspen_env
 # Install requirements
 pip install -r requirements.txt
 ```
 import aspen
 ```
 
 ## Quickstart
-The `mlora.py` code is a starting point for finetuning and inference on various datasets.
+The `mlora.py` code is a starting point for finetuning on various datasets.
 Basic command for finetuning a base model on the Alpaca dataset:
 ```bash
-python mlora.py --model_name_or_path <path_or_name>
+python mlora.py \
+  --base_model decapoda-research/llama-7b-hf \
+  --config ./config/alpaca.json \
+  --load_8bit true
 ```
 
-For models larger than 13B, we recommend adjusting the learning rate:
-```bash
-python mlora.py --learning_rate 0.0001 --model_name_or_path <path_or_name>
-```
+You can check the template finetune configuration in the [config](./config/) folder.
 
-You can check detailed usage information by `--help` option:
+For more detailed usage information, please use the `--help` option:
 ```bash
 python mlora.py --help
 ```

From ab46a085fcfbf3596d686edef2f70eb57aa00565 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E7=99=BB=E6=B7=B3?=
Date: Wed, 30 Aug 2023 13:38:01 +0800
Subject: [PATCH 3/3] Update README.md

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 56776986..f434b019 100644
--- a/README.md
+++ b/README.md
@@ -40,7 +40,7 @@ python mlora.py \
   --load_8bit true
 ```
 
-You can check the template finetune configuration in the [config](./config/) folder.
+You can check the template finetune configuration in the [template](./template/) folder.
 
 For more detailed usage information, please use the `--help` option:
 ```bash