-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Rename mlora.py -> legacy.py * Add bitsandbytes to requirements.txt * Create wrapped entrance code * Update README.md * Merge changes from @yezhengmao
- Loading branch information
1 parent
98a31de
commit f5972ca
Showing
4 changed files
with
186 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from aspen import LlamaModel, Tokenizer, DataSet | ||
from aspen import LlamaModelArgs, MultiLoraBatchData | ||
from aspen import load_llama_7b_weight, load_random_lora_7b_weight | ||
from aspen import save_lora_model | ||
|
||
import json | ||
import torch | ||
import torch.optim | ||
|
||
with open('config/lora.json', 'r', encoding='utf8') as fp: | ||
config = json.load(fp) | ||
|
||
args = LlamaModelArgs() | ||
tokenizer = Tokenizer(config["token_model"]) | ||
tokenizer.pad_id_ = 0 | ||
args.max_seq_len_ = 4096 | ||
args.device = config["device"] | ||
args.vocab_size_ = tokenizer.n_words_ | ||
args.pad_id_ = tokenizer.pad_id_ | ||
args.n_heads_ = 32 | ||
llama_model = LlamaModel(args) | ||
|
||
|
||
def setup_seed(seed): | ||
torch.manual_seed(seed) | ||
torch.cuda.manual_seed_all(seed) | ||
|
||
|
||
def init_lora_model(llama_model: LlamaModel): | ||
for lora_config in config["lora"]: | ||
load_random_lora_7b_weight( | ||
llama_model, | ||
lora_config["name"], | ||
lora_config["r"], | ||
llama_model.dim_, | ||
lora_config["target_modules"], | ||
llama_model.device_) | ||
llama_model.update_lora_configure( | ||
lora_config["name"], lora_config["r"], lora_config["alpha"], lora_config["dropout"]) | ||
|
||
|
||
if __name__ == "__main__": | ||
setup_seed(42) | ||
|
||
data_set = DataSet(config, tokenizer) | ||
load_llama_7b_weight(llama_model, config["base_model"], config["device"]) | ||
init_lora_model(llama_model) | ||
|
||
torch.cuda.empty_cache() | ||
|
||
optimizer = torch.optim.AdamW(llama_model.get_train_paramas(config)) | ||
|
||
step_cnt = 0 | ||
while not data_set.check_done(): | ||
optimizer.zero_grad() | ||
loss_fn = torch.nn.CrossEntropyLoss() | ||
input: MultiLoraBatchData = data_set.get_batch_data() | ||
|
||
step_cnt += 1 | ||
|
||
output = llama_model.forward(input) | ||
labels = torch.tensor(input.batch_tokens_, | ||
dtype=torch.long).to(config["device"]) | ||
|
||
total_loss = None | ||
for lora_config in input.lora_batch_data_config_: | ||
start_idx = lora_config.batch_start_idx_ | ||
end_idx = lora_config.batch_end_idx_ | ||
loss_input = output[start_idx:end_idx][..., :-1, | ||
:].contiguous().view(-1, llama_model.vocab_size_) | ||
loss_target = labels[start_idx:end_idx][..., | ||
1:].contiguous().view(-1) | ||
loss = loss_fn(loss_input, loss_target) | ||
print( | ||
f" adapter: {lora_config.adapter_name_} loss: {loss}") | ||
if total_loss is None: | ||
total_loss = loss | ||
else: | ||
total_loss += loss | ||
|
||
total_loss.backward() | ||
optimizer.step() | ||
|
||
if step_cnt % config["save_step"] == 0: | ||
for lora_config in config["lora"]: | ||
save_lora_model( | ||
llama_model, lora_config["output"] + f".bin{step_cnt}", lora_config["name"]) | ||
|
||
for lora_config in config["lora"]: | ||
save_lora_model( | ||
llama_model, lora_config["output"], lora_config["name"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,91 +1,59 @@ | ||
from aspen import LlamaModel, Tokenizer, DataSet | ||
from aspen import LlamaModelArgs, MultiLoraBatchData | ||
from aspen import load_llama_7b_weight, load_random_lora_7b_weight | ||
from aspen import save_lora_model | ||
|
||
import json | ||
# ASPEN: Efficient Multi-LoRA Fine Tuning with Shared-Based Model | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# Copyright (C) 2023 All Rights Reserved. | ||
# | ||
# Email: | ||
# Github: https://github.com/TUDB-Labs/multi-lora-fine-tune | ||
# Website: | ||
|
||
import datetime | ||
import argparse | ||
import torch | ||
import torch.optim | ||
|
||
with open('config/lora.json', 'r', encoding='utf8') as fp: | ||
config = json.load(fp) | ||
|
||
args = LlamaModelArgs() | ||
tokenizer = Tokenizer(config["token_model"]) | ||
tokenizer.pad_id_ = 0 | ||
args.max_seq_len_ = 4096 | ||
args.device = config["device"] | ||
args.vocab_size_ = tokenizer.n_words_ | ||
args.pad_id_ = tokenizer.pad_id_ | ||
args.n_heads_ = 32 | ||
llama_model = LlamaModel(args) | ||
|
||
|
||
def setup_seed(seed): | ||
torch.manual_seed(seed) | ||
torch.cuda.manual_seed_all(seed) | ||
|
||
|
||
def init_lora_model(llama_model: LlamaModel): | ||
for lora_config in config["lora"]: | ||
load_random_lora_7b_weight( | ||
llama_model, | ||
lora_config["name"], | ||
lora_config["r"], | ||
llama_model.dim_, | ||
lora_config["target_modules"], | ||
llama_model.device_) | ||
llama_model.update_lora_configure( | ||
lora_config["name"], lora_config["r"], lora_config["alpha"], lora_config["dropout"]) | ||
|
||
import aspen | ||
import os | ||
|
||
parser = argparse.ArgumentParser(description='ASPEN main program') | ||
parser.add_argument('--model_name_or_path', type=str, help='Path to or name of base model') | ||
parser.add_argument('--device', type=str, default='cuda:0', help='Specify which GPU to be used, default is cuda:0') | ||
parser.add_argument('--log', type=bool, default=True, help='Turn on or off log, default is true') | ||
|
||
args = parser.parse_args() | ||
|
||
def log(msg:str): | ||
if args.log: | ||
print('[%s] ASPEN: %s' % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), msg)) | ||
|
||
if torch.cuda.is_available(): | ||
log('NVIDIA CUDA initialized successfully.') | ||
log('Total %i GPU(s) detected.' % torch.cuda.device_count()) | ||
else: | ||
print('ASPEN requires NVIDIA CUDA computing capacity. Please check your PyTorch installation.') | ||
exit(-1) | ||
|
||
def prep_llm(): | ||
args = aspen.LlamaModelArgs() | ||
tokenizer = aspen.Tokenizer(args.model_name_or_path + os.sep + 'tokenizer.model') | ||
tokenizer.pad_id_ = 0 | ||
args.max_seq_len_ = 4096 | ||
args.device = args.device | ||
args.vocab_size_ = tokenizer.n_words_ | ||
args.pad_id_ = tokenizer.pad_id_ | ||
args.n_heads_ = 32 | ||
model = aspen.LlamaModel(args) | ||
aspen.load_llama_tf_weight(model, args.model_name_or_path, args.device) | ||
return tokenizer, model | ||
|
||
if __name__ == "__main__": | ||
setup_seed(42) | ||
|
||
data_set = DataSet(config, tokenizer) | ||
load_llama_7b_weight(llama_model, config["base_model"], config["device"]) | ||
init_lora_model(llama_model) | ||
|
||
torch.cuda.empty_cache() | ||
|
||
optimizer = torch.optim.AdamW(llama_model.get_train_paramas(config)) | ||
|
||
step_cnt = 0 | ||
while not data_set.check_done(): | ||
optimizer.zero_grad() | ||
loss_fn = torch.nn.CrossEntropyLoss() | ||
input: MultiLoraBatchData = data_set.get_batch_data() | ||
|
||
step_cnt += 1 | ||
|
||
output = llama_model.forward(input) | ||
labels = torch.tensor(input.batch_tokens_, | ||
dtype=torch.long).to(config["device"]) | ||
|
||
total_loss = None | ||
for lora_config in input.lora_batch_data_config_: | ||
start_idx = lora_config.batch_start_idx_ | ||
end_idx = lora_config.batch_end_idx_ | ||
loss_input = output[start_idx:end_idx][..., :-1, | ||
:].contiguous().view(-1, llama_model.vocab_size_) | ||
loss_target = labels[start_idx:end_idx][..., | ||
1:].contiguous().view(-1) | ||
loss = loss_fn(loss_input, loss_target) | ||
print( | ||
f" adapter: {lora_config.adapter_name_} loss: {loss}") | ||
if total_loss is None: | ||
total_loss = loss | ||
else: | ||
total_loss += loss | ||
|
||
total_loss.backward() | ||
optimizer.step() | ||
|
||
if step_cnt % config["save_step"] == 0: | ||
for lora_config in config["lora"]: | ||
save_lora_model( | ||
llama_model, lora_config["output"] + f".bin{step_cnt}", lora_config["name"]) | ||
|
||
for lora_config in config["lora"]: | ||
save_lora_model( | ||
llama_model, lora_config["output"], lora_config["name"]) | ||
tokenizer, model = prep_llm() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
torch==2.0.1 | ||
sentencepiece | ||
bitsandbytes | ||
xformers | ||
einops |