
Commit 562d868

* refactor the configure file / dataset
* refactor the save model function
* add test case
yezhengmao1 committed Aug 29, 2023
1 parent 4b0fefb commit 562d868
Showing 8 changed files with 166 additions and 68 deletions.
7 changes: 6 additions & 1 deletion .gitignore
@@ -165,4 +165,9 @@ cython_debug/
# ASPEN
__pycache__/
*.egg-info/
*.egg
*.egg

data/*
template/*
!data/data_demo.json
!template/template_demo.json
50 changes: 34 additions & 16 deletions aspen/dataset.py
@@ -21,7 +21,16 @@ def __get_lora_text_data(self) -> Dict[str, List[str]]:
lora_text_data = {}
for lora_config in self.config_["lora"]:
lora_name = lora_config["name"]
lora_template = lora_config["prompt"]
data_path = lora_config["data"]

with open(lora_template, 'r', encoding='utf8') as fp:
template_config = json.load(fp)

template_parameter_list = template_config["parameter"]
template_prompt = template_config["prompt"]
template_prompt_no_input = template_config["prompt_no_input"]

lora_text_data[lora_name] = []
self.lora_cnt_epochs_[lora_name] = 0
self.lora_start_idx_[lora_name] = 0
@@ -30,18 +39,27 @@ def __get_lora_text_data(self) -> Dict[str, List[str]]:

with open(data_path, 'r', encoding='utf8') as fp:
for raw_data in json.load(fp):
raw_data_input = raw_data["input"]
raw_data_output = raw_data["output"]
raw_data_instruction = raw_data["instruction"]
text_data = ""
if raw_data_input is None or len(raw_data_input) <= 1:
text_data = lora_config["prompt_no_input"].replace(
"{output}", raw_data_output).replace("{instruction}", raw_data_instruction)
raw_data_input = {}

no_input_flag = False
for para in template_parameter_list:
if para not in raw_data:
no_input_flag = True
continue
raw_data_input[para] = raw_data[para]

text_data: str = ""
if no_input_flag:
text_data = template_prompt_no_input
else:
text_data = lora_config["prompt_input"].replace(
"{output}", raw_data_output).replace(
"{instruction}", raw_data_instruction).replace(
"{input}", raw_data_input)
text_data = template_prompt

for para in template_parameter_list:
if para not in raw_data_input:
continue
text_data = text_data.replace(
"{" + para + "}", raw_data[para])

lora_text_data[lora_name].append(text_data)

return lora_text_data
@@ -116,17 +134,17 @@ def get_batch_data(self) -> MultiLoraBatchData:
batch_start_idx += len(prompt_and_tokens_list)
prompts_batch_config_list.append(lora_config)

self.lora_start_idx_[lora_name] += self.lora_batch_size_[lora_name]
if self.lora_start_idx_[lora_name] >= len(self.lora_token_data_[lora_name]):
self.lora_start_idx_[lora_name] = 0
self.lora_cnt_epochs_[lora_name] += 1

print(f"{lora_name} train data:")
print(
f" epoch: {self.lora_cnt_epochs_[lora_name] + 1} / {self.lora_num_epochs_[lora_name]}")
print(
f" step : {self.lora_start_idx_[lora_name]} / {len(self.lora_token_data_[lora_name])}")

self.lora_start_idx_[lora_name] += self.lora_batch_size_[lora_name]
if self.lora_start_idx_[lora_name] >= len(self.lora_token_data_[lora_name]):
self.lora_start_idx_[lora_name] = 0
self.lora_cnt_epochs_[lora_name] += 1

# align batch data
max_token_len = math.ceil(max_token_len / 8) * 8

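The refactored loader replaces the inline prompt_input / prompt_no_input strings in the LoRA config with a separate template file that declares its parameters. A minimal, self-contained sketch of that substitution logic (render_prompt is a hypothetical helper for illustration, not part of aspen; the template keys follow template/template_demo.json):

    import json
    from typing import Dict

    def render_prompt(template_path: str, record: Dict[str, str]) -> str:
        # A template declares a "parameter" list plus "prompt" / "prompt_no_input" strings.
        with open(template_path, "r", encoding="utf8") as fp:
            template = json.load(fp)

        # If any declared parameter is missing from the record, fall back to prompt_no_input.
        missing = [p for p in template["parameter"] if p not in record]
        prompt = template["prompt_no_input"] if missing else template["prompt"]

        # Fill every available parameter into its {placeholder}.
        for para in template["parameter"]:
            if para in record:
                prompt = prompt.replace("{" + para + "}", record[para])
        return prompt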
66 changes: 47 additions & 19 deletions aspen/modelloader.py
@@ -1,7 +1,10 @@
from aspen.model import LlamaModel, Linear, RMSNorm

import os
import sys
import json
import torch
from typing import Dict

from transformers import LlamaForCausalLM

@@ -100,7 +103,8 @@ def load_random_lora_7b_weight(model: LlamaModel,
device: str) -> None:
norm_mean = 0
norm_std = 1e-3
target_module_name_list = ["q_proj", "k_proj", "v_proj", "o_proj", "w1_proj", "w2_proj", "w3_proj"]
target_module_name_list = ["q_proj", "k_proj",
"v_proj", "o_proj", "w1_proj", "w2_proj", "w3_proj"]
for transformer_layer in model.layers_:
target_layer_list = [transformer_layer.wq_, transformer_layer.wk_,
transformer_layer.wv_, transformer_layer.wo_,
@@ -118,21 +122,45 @@ def load_random_lora_7b_weight(model: LlamaModel,
adapter_name, "lora_B", lora_b_weight)


def save_lora_model(model: LlamaModel, path: str, lora_name: str):
lora_weight_dict = {}
for idx, transformer_layer in enumerate(model.layers_):
layer_prefix_name = "base_model.model.model.layers." + \
str(idx) + "." + "self_attn."
lora_layer_list = [transformer_layer.wq_, transformer_layer.wk_,
transformer_layer.wv_, transformer_layer.wo_,
transformer_layer.w1_, transformer_layer.w2_,
transformer_layer.w3_]
lora_layer_name_list = ["q_proj", "k_proj", "v_proj", "o_proj", "w1_proj", "w2_proj", "w3_proj"]
for idx, lora_layer in enumerate(lora_layer_list):
if lora_name in lora_layer.loras_:
lora_weight_dict[layer_prefix_name +
f"{lora_layer_name_list[idx]}.lora_A.weight"] = lora_layer.loras_[lora_name].lora_a_
lora_weight_dict[layer_prefix_name +
f"{lora_layer_name_list[idx]}.lora_B.weight"] = lora_layer.loras_[lora_name].lora_b_

torch.save(lora_weight_dict, path)
def save_lora_model(model: LlamaModel, config: Dict[str, str]):
for lora_config in config["lora"]:
lora_name = lora_config["name"]
lora_output_dir = lora_config["output"]

if not os.path.exists(lora_output_dir):
os.makedirs(lora_output_dir)

lora_weight_dict = {}
target_modules = []
for idx, transformer_layer in enumerate(model.layers_):
layer_prefix_name = "base_model.model.model.layers." + \
str(idx) + "." + "self_attn."
lora_layer_list = [transformer_layer.wq_, transformer_layer.wk_,
transformer_layer.wv_, transformer_layer.wo_,
transformer_layer.w1_, transformer_layer.w2_,
transformer_layer.w3_]
lora_layer_name_list = [
"q_proj", "k_proj", "v_proj", "o_proj", "w1_proj", "w2_proj", "w3_proj"]
for idx, lora_layer in enumerate(lora_layer_list):
if lora_name in lora_layer.loras_:
if lora_layer_name_list[idx] not in target_modules:
target_modules.append(lora_layer_name_list[idx])
lora_weight_dict[layer_prefix_name +
f"{lora_layer_name_list[idx]}.lora_A.weight"] = lora_layer.loras_[lora_name].lora_a_
lora_weight_dict[layer_prefix_name +
f"{lora_layer_name_list[idx]}.lora_B.weight"] = lora_layer.loras_[lora_name].lora_b_

torch.save(lora_weight_dict, lora_output_dir +
"/" + "adapter_model.bin")

adapter_config = {}
adapter_config["lora_alpha"] = lora_config["alpha"]
adapter_config["lora_dropout"] = lora_config["dropout"]
adapter_config["r"] = lora_config["r"]
adapter_config["peft_type"] = "LORA"
adapter_config["task_type"] = "CAUSAL_LM"
adapter_config["bias"] = "none"
adapter_config["target_modules"] = target_modules

with open(lora_output_dir + "/" + "adapter_config.json", "w") as f:
json.dump(adapter_config, f, indent=4)
30 changes: 17 additions & 13 deletions config/lora.json
@@ -1,15 +1,15 @@
{
"base_model": "/modules/llama-7b/7B/consolidated.00.pth",
"token_model": "/modules/llama-7b/tokenizer.model",
"cutoff_len": 512,
"base_model": "",
"token_model": "",
"cutoff_len": 256,
"group_by_length": false,
"expand_right": true,
"device": "cuda:1",
"save_step": 200,
"lora": [
{
"name": "lora_0",
"output": "lora_0.bin",
"output": "lora_0",
"batch_size": 16,
"num_epochs": 3,
"r": 8,
@@ -19,15 +19,17 @@
"q_proj": true,
"k_proj": true,
"v_proj": true,
"o_proj": true
"o_proj": true,
"w1_proj": false,
"w2_proj": false,
"w3_proj": false
},
"data": "",
"prompt_input": "",
"prompt_no_input": ""
"data": "data/data_demo.json",
"prompt": "template/template_demo.json"
},
{
"name": "lora_1",
"output": "lora_1.bin",
"output": "lora_1",
"batch_size": 16,
"num_epochs": 3,
"r": 8,
@@ -37,11 +39,13 @@
"q_proj": true,
"k_proj": true,
"v_proj": true,
"o_proj": true
"o_proj": true,
"w1_proj": false,
"w2_proj": false,
"w3_proj": false
},
"data": "",
"prompt_input": "",
"prompt_no_input": ""
"data": "data/data_demo.json",
"prompt": "template/template_demo.json"
}
]
}
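Each entry under "lora" now points at a data file ("data") and a prompt template ("prompt") instead of carrying prompt strings inline, and "output" names a directory rather than a .bin file. A small sketch of a sanity check over the per-adapter keys this commit reads (the key set is inferred from this config plus dataset.py and modelloader.py, not an official schema):

    import json

    # Keys referenced by the dataset loader and save_lora_model in this commit.
    REQUIRED_KEYS = {"name", "output", "batch_size", "num_epochs",
                     "r", "alpha", "dropout", "data", "prompt"}

    with open("config/lora.json", "r", encoding="utf8") as fp:
        config = json.load(fp)

    for lora_config in config["lora"]:
        missing = REQUIRED_KEYS - lora_config.keys()
        if missing:
            raise ValueError(f"{lora_config.get('name', '?')}: missing {sorted(missing)}")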
11 changes: 11 additions & 0 deletions data/data_demo.json
@@ -0,0 +1,11 @@
[
{
"instruction": "Instruction demo.",
"input": "Input demo.",
"output": "Output demo."
},
{
"instruction": "Instruction demo.",
"output": "Output demo."
}
]
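Note that the second demo record deliberately omits the "input" field, so the loader falls back to the template's prompt_no_input; the second branch of the new test in tests/loader_test.py asserts exactly that rendering.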
9 changes: 4 additions & 5 deletions legacy.py
@@ -40,7 +40,8 @@ def init_lora_model(llama_model: aspen.LlamaModel):
setup_seed(42)

data_set = aspen.DataSet(config, tokenizer)
aspen.load_llama_7b_weight(llama_model, config["base_model"], config["device"])
aspen.load_llama_7b_weight(
llama_model, config["base_model"], config["device"])
init_lora_model(llama_model)

torch.cuda.empty_cache()
@@ -79,8 +80,6 @@ def init_lora_model(llama_model: aspen.LlamaModel):
optimizer.step()

if step_cnt % config["save_step"] == 0:
for lora_config in config["lora"]:
aspen.save_lora_model(llama_model, lora_config["output"] + f".bin{step_cnt}", lora_config["name"])
aspen.save_lora_model(llama_model, config)

for lora_config in config["lora"]:
aspen.save_lora_model(llama_model, lora_config["output"], lora_config["name"])
aspen.save_lora_model(llama_model, config)
10 changes: 10 additions & 0 deletions template/template_demo.json
@@ -0,0 +1,10 @@
{
"description": "",
"parameter": [
"input",
"output",
"instruction"
],
"prompt": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Output:\n{output}\n",
"prompt_no_input": "### Instruction:\n{instruction}\n\n### Output:\n{output}\n"
}
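Filling the demo template with the first record from data/data_demo.json yields exactly the string the new test in tests/loader_test.py expects; a quick self-contained check of the placeholder substitution:

    prompt = ("### Instruction:\n{instruction}\n\n"
              "### Input:\n{input}\n\n"
              "### Output:\n{output}\n")
    record = {"instruction": "Instruction demo.",
              "input": "Input demo.",
              "output": "Output demo."}

    # Replace each {placeholder} with the matching field from the record.
    for key, value in record.items():
        prompt = prompt.replace("{" + key + "}", value)

    assert prompt == ("### Instruction:\nInstruction demo.\n\n"
                      "### Input:\nInput demo.\n\n"
                      "### Output:\nOutput demo.\n")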
51 changes: 37 additions & 14 deletions tests/loader_test.py
@@ -1,20 +1,43 @@
from aspen import DataSet

import json
import unittest

from typing import List

with open('config/lora.json', 'r', encoding='utf8') as fp:
config = json.load(fp)


class MockTokenizer:
def __init__(self):
self.pad_id_ = 0

def encode(self, text: str, bos: bool, eos: bool) -> List[int]:
return [0] * len(text)


class TestDataLoader(unittest.TestCase):
def test_load_llma(self):
"""
Todo: add the test function here
:return:
"""
self.assertEqual(1 + 1, 2)

def test_load_llam2(self):
"""
Todo: add the test function here
:return:
"""
self.assertEqual(2 - 1, 1)
class TestDataSet(unittest.TestCase):
def test_load_dataset(self):
dataset = DataSet(config, MockTokenizer())
input_data = dataset.get_batch_data()
self.assertEqual(len(input_data.prompts_), 4)
for p in input_data.prompts_:
if "Input" in p:
self.assertEqual(
"### Instruction:\nInstruction demo.\n\n### Input:\nInput demo.\n\n### Output:\nOutput demo.\n", p)
else:
self.assertEqual(
"### Instruction:\nInstruction demo.\n\n### Output:\nOutput demo.\n", p)
self.assertEqual(len(input_data.lora_batch_data_config_), 2)
self.assertEqual(input_data.lora_batch_data_config_[
0].batch_start_idx_, 0)
self.assertEqual(input_data.lora_batch_data_config_[
0].batch_end_idx_, 2)
self.assertEqual(input_data.lora_batch_data_config_[
1].batch_start_idx_, 2)
self.assertEqual(input_data.lora_batch_data_config_[
1].batch_end_idx_, 4)


if __name__ == '__main__':
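The new test replaces the earlier placeholder cases and exercises the full DataSet pipeline against the demo files added in this commit. It imports aspen and opens config/lora.json by relative path, so it is expected to run from the repository root with the aspen package importable and the demo data and template files in place.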
