Some minor improvement #10

Merged 6 commits on Aug 29, 2023
44 changes: 38 additions & 6 deletions README.md
@@ -7,16 +7,27 @@ This repository provides tools for fine-tuning large language models (LLMs) usin

- [Updates](#updates)
- [Overview](#overview)
- [Installation](#installation)
- [Getting Started](#quickstart)
- [Contributing](#contributing)
- [License](#license)
- [Contributing](#contributing)
- [Copyright](#copyright)

## Updates
Support

## Overview


## Installation
```bash
# Optional but recommended: create and activate a dedicated environment
# (torch==2.0.1 in requirements.txt needs Python 3.8 or newer)
conda create -n aspen_env python=3.8
conda activate aspen_env
# Install requirements
pip install -r requirements.txt
```
After installation, you can use ASPEN directly in your code:
```python
import aspen
```
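
For example, a tokenizer and base model can be prepared along the lines of the `prep_llm` helper in `mlora.py` (a minimal sketch; the checkpoint directory, which should contain `tokenizer.model`, and the device string are placeholders):
```python
import os

import aspen

# Placeholder path and device; adjust to your environment.
model_path = '/path/to/llama-7b'   # directory with tokenizer.model and the model weights
device = 'cuda:0'

model_args = aspen.LlamaModelArgs()
tokenizer = aspen.Tokenizer(model_path + os.sep + 'tokenizer.model')
tokenizer.pad_id_ = 0
model_args.max_seq_len_ = 4096
model_args.device = device
model_args.vocab_size_ = tokenizer.n_words_
model_args.pad_id_ = tokenizer.pad_id_
model_args.n_heads_ = 32

model = aspen.LlamaModel(model_args)
aspen.load_llama_tf_weight(model, model_path, device)
```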
## Quickstart

The `mlora.py` code is a starting point for finetuning and inference on various datasets.
@@ -27,7 +38,12 @@ python mlora.py --model_name_or_path <path_or_name>

For models larger than 13B, we recommend adjusting the learning rate:
```bash
python mlora.py –learning_rate 0.0001 --model_name_or_path <path_or_name>
python mlora.py --learning_rate 0.0001 --model_name_or_path <path_or_name>
```

You can check detailed usage information with the `--help` option:
```bash
python mlora.py --help
```

## Contributing
@@ -48,5 +64,21 @@ Submit a pull request with a detailed explanation of your changes.
}
```

## License
This project is licensed under the Apache 2.0 License - see the LICENSE file for details
## Copyright
Copyright © 2023 All Rights Reserved.

This project is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0).

```
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```
94 changes: 94 additions & 0 deletions legacy.py
@@ -0,0 +1,94 @@
import json
import torch
from aspen import LlamaModel, Tokenizer, DataSet
from aspen import LlamaModelArgs, MultiLoraBatchData
from aspen import load_llama_7b_weight, load_random_lora_7b_weight
from aspen import save_lora_model
import torch.optim

with open('config/lora.json', 'r', encoding='utf8') as fp:
config = json.load(fp)
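
# The config file provides the tokenizer path ("token_model"), the base model
# checkpoint ("base_model"), the target device, and a "lora" list with
# per-adapter settings (name, r, alpha, dropout, target_modules, output).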

args = LlamaModelArgs()
tokenizer = Tokenizer(config["token_model"])
tokenizer.pad_id_ = 0
args.max_seq_len_ = 4096
args.device = config["device"]
args.vocab_size_ = tokenizer.n_words_
args.pad_id_ = tokenizer.pad_id_
args.n_heads_ = 32
llama_model = LlamaModel(args)


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def init_lora_model(llama_model: LlamaModel):
    for lora_config in config["lora"]:
        load_random_lora_7b_weight(
            llama_model,
            lora_config["name"],
            lora_config["r"],
            llama_model.dim_,
            lora_config["target_modules"],
            llama_model.device_)
        llama_model.update_lora_configure(
            lora_config["name"], lora_config["r"], lora_config["alpha"], lora_config["dropout"])


if __name__ == "__main__":
    setup_seed(42)

    data_set = DataSet(config, tokenizer)
    load_llama_7b_weight(llama_model, config["base_model"], config["device"])
    init_lora_model(llama_model)

    torch.cuda.empty_cache()

    # optim begin
    optimizer = torch.optim.SGD(
        llama_model.get_train_paramas(config), lr=1e-3)
    # optim end

    step = 0
    # torch.autograd.set_detect_anomaly(True)
    while not data_set.check_done():
        optimizer.zero_grad()
        loss_fn = torch.nn.CrossEntropyLoss()
        input: MultiLoraBatchData = data_set.get_batch_data()

        step += 1

        output = llama_model.forward(input)
        labels = torch.tensor(input.batch_tokens_,
                              dtype=torch.long).to(config["device"])

        total_loss = None
        for lora_config in input.lora_batch_data_config_:
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_
            loss_input = output[start_idx:end_idx][..., :-1,
                                                   :].contiguous().view(-1, llama_model.vocab_size_)
            loss_target = labels[start_idx:end_idx][...,
                                                    1:].contiguous().view(-1)
            loss = loss_fn(loss_input, loss_target)
            print(
                f" adapter: {lora_config.adapter_name_} loss: {loss}")
            if total_loss is None:
                total_loss = loss
            else:
                total_loss += loss

        total_loss.backward()
        optimizer.step()

        if step % 200 == 0:
            for lora_config in config["lora"]:
                save_lora_model(
                    llama_model, lora_config["output"] + f".chk{step}", lora_config["name"])

    for lora_config in config["lora"]:
        save_lora_model(
            llama_model, lora_config["output"], lora_config["name"])
147 changes: 56 additions & 91 deletions mlora.py
@@ -1,94 +1,59 @@
import json
# ASPEN: Efficient Multi-LoRA Fine Tuning with a Shared Base Model
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright (C) 2023 All Rights Reserved.
#
# Email:
# Github: https://github.com/TUDB-Labs/multi-lora-fine-tune
# Website:

import datetime
import argparse
import torch
from aspen import LlamaModel, Tokenizer, DataSet
from aspen import LlamaModelArgs, MultiLoraBatchData
from aspen import load_llama_7b_weight, load_random_lora_7b_weight
from aspen import save_lora_model
import torch.optim

with open('config/lora.json', 'r', encoding='utf8') as fp:
config = json.load(fp)

args = LlamaModelArgs()
tokenizer = Tokenizer(config["token_model"])
tokenizer.pad_id_ = 0
args.max_seq_len_ = 4096
args.device = config["device"]
args.vocab_size_ = tokenizer.n_words_
args.pad_id_ = tokenizer.pad_id_
args.n_heads_ = 32
llama_model = LlamaModel(args)


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def init_lora_model(llama_model: LlamaModel):
    for lora_config in config["lora"]:
        load_random_lora_7b_weight(
            llama_model,
            lora_config["name"],
            lora_config["r"],
            llama_model.dim_,
            lora_config["target_modules"],
            llama_model.device_)
        llama_model.update_lora_configure(
            lora_config["name"], lora_config["r"], lora_config["alpha"], lora_config["dropout"])

import aspen
import os

parser = argparse.ArgumentParser(description='ASPEN main program')
parser.add_argument('--model_name_or_path', type=str, help='Path to or name of base model')
parser.add_argument('--device', type=str, default='cuda:0', help='Specify which GPU to use, default is cuda:0')
parser.add_argument('--log', type=bool, default=True, help='Turn logging on or off, default is true')

args = parser.parse_args()

def log(msg: str):
    if args.log:
        print('[%s] ASPEN: %s' % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), msg))


if torch.cuda.is_available():
    log('NVIDIA CUDA initialized successfully.')
    log('Total %i GPU(s) detected.' % torch.cuda.device_count())
else:
    print('ASPEN requires an NVIDIA GPU with CUDA support. Please check your PyTorch installation.')
    exit(-1)


def prep_llm():
    # Use a distinct name for the model arguments so the command-line
    # `args` (model path, device) parsed above is not shadowed.
    model_args = aspen.LlamaModelArgs()
    tokenizer = aspen.Tokenizer(args.model_name_or_path + os.sep + 'tokenizer.model')
    tokenizer.pad_id_ = 0
    model_args.max_seq_len_ = 4096
    model_args.device = args.device
    model_args.vocab_size_ = tokenizer.n_words_
    model_args.pad_id_ = tokenizer.pad_id_
    model_args.n_heads_ = 32
    model = aspen.LlamaModel(model_args)
    aspen.load_llama_tf_weight(model, args.model_name_or_path, args.device)
    return tokenizer, model

if __name__ == "__main__":
    setup_seed(42)

    data_set = DataSet(config, tokenizer)
    load_llama_7b_weight(llama_model, config["base_model"], config["device"])
    init_lora_model(llama_model)

    torch.cuda.empty_cache()

    # optim begin
    optimizer = torch.optim.SGD(
        llama_model.get_train_paramas(config), lr=1e-3)
    # optim end

    step = 0
    # torch.autograd.set_detect_anomaly(True)
    while not data_set.check_done():
        optimizer.zero_grad()
        loss_fn = torch.nn.CrossEntropyLoss()
        input: MultiLoraBatchData = data_set.get_batch_data()

        step += 1

        output = llama_model.forward(input)
        labels = torch.tensor(input.batch_tokens_,
                              dtype=torch.long).to(config["device"])

        total_loss = None
        for lora_config in input.lora_batch_data_config_:
            start_idx = lora_config.batch_start_idx_
            end_idx = lora_config.batch_end_idx_
            loss_input = output[start_idx:end_idx][..., :-1,
                                                   :].contiguous().view(-1, llama_model.vocab_size_)
            loss_target = labels[start_idx:end_idx][...,
                                                    1:].contiguous().view(-1)
            loss = loss_fn(loss_input, loss_target)
            print(
                f" adapter: {lora_config.adapter_name_} loss: {loss}")
            if total_loss is None:
                total_loss = loss
            else:
                total_loss += loss

        total_loss.backward()
        optimizer.step()

        if step % 200 == 0:
            for lora_config in config["lora"]:
                save_lora_model(
                    llama_model, lora_config["output"] + f".chk{step}", lora_config["name"])

    for lora_config in config["lora"]:
        save_lora_model(
            llama_model, lora_config["output"], lora_config["name"])
    tokenizer, model = prep_llm()
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,4 +1,5 @@
torch==2.0.1
sentencepiece
bitsandbytes
xformers
einops