bug fixed: NaN issue under the AdamW optimizer; change fp16 precision to fp32
* Refactor away some code smells
yezhengmao1 committed Aug 29, 2023
1 parent 9202016 commit f39502c
Showing 4 changed files with 110 additions and 178 deletions.
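For context on the headline fix: AdamW's default eps (1e-8) and the small second-moment values typical of LoRA gradients both sit below fp16's smallest representable magnitude, so the Adam denominator collapses to zero and the update overflows. A minimal sketch of that failure mode, not part of this commit:

```python
import torch

# Minimal sketch (not part of this commit): AdamW's step is roughly m / (sqrt(v) + eps).
# In fp16 both the default eps (1e-8) and a small second moment v underflow to 0,
# so the denominator becomes 0 and the step is inf, which later turns into NaN.
m = torch.tensor(1e-4)    # first-moment estimate (running mean of gradients)
v = torch.tensor(1e-9)    # second-moment estimate (running mean of grad**2)
eps = torch.tensor(1e-8)  # AdamW default eps

print(eps.half())                                 # tensor(0., dtype=torch.float16)
print(m.half() / (v.half().sqrt() + eps.half()))  # tensor(inf, dtype=torch.float16)
print(m / (v.sqrt() + eps))                       # about 3.16, finite in fp32
```

Keeping the LoRA weights, masks, and norm/output weights in fp32, as the hunks below do, keeps the optimizer state representable.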
116 changes: 63 additions & 53 deletions aspen/model.py
@@ -1,6 +1,5 @@
from aspen.modelargs import LlamaModelArgs, MultiLoraBatchData
from aspen import LlamaModelArgs, MultiLoraBatchData

import time
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
@@ -18,13 +17,13 @@ def precompute_rope_angle(dim: int, seq_len: int, device: str, theta: float = 10
emb = torch.outer(seq, angles).float()
emb = einops.repeat(emb, "... n -> ... (n r)", r=2)
# cos(angle), sin(angle)
return (emb.cos().to(torch.float16), emb.sin().to(torch.float16))
return (emb.cos().to(torch.float32), emb.sin().to(torch.float32))


def precompute_mask(input: MultiLoraBatchData, n_head: int, device: str) -> torch.Tensor:
mask = torch.full((len(input.prompts_), n_head,
input.batch_seq_len_, input.batch_seq_len_), float("-inf"))
mask = torch.triu(mask, diagonal=1).to(torch.float16).cuda(device)
mask = torch.triu(mask, diagonal=1).to(torch.float32).cuda(device)

for idx, _ in enumerate(input.prompts_):
zero_len = input.tokens_len_without_pad_[idx]
@@ -36,7 +35,7 @@ def precompute_mask(input: MultiLoraBatchData, n_head: int, device: str) -> torc
mask[idx] += torch.tensor([float("-inf")] * inf_len + [0] * zero_len).expand(
input.batch_seq_len_, input.batch_seq_len_).cuda(device)

return mask.to(torch.float16)
return mask.to(torch.float32)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
@@ -61,7 +60,7 @@ def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, angle: Tuple[torch.Tens
class RMSNorm():
def __init__(self, weight: torch.Tensor, eps: float = 1e-06):
self.norm_eps_ = eps
self.weight_ = weight
self.weight_ = weight.to(torch.float32)

def _norm(self, data: torch.Tensor) -> torch.Tensor:
return data * torch.rsqrt(data.pow(2).mean(-1, keepdim=True) + self.norm_eps_)
@@ -70,6 +69,32 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
return self._norm(data.float()).type_as(data) * self.weight_


class Lora():
def __init__(self, adapter_name: str):
self.adapter_name_: str = adapter_name

self.lora_a_: torch.Tensor = None
self.lora_b_: torch.Tensor = None

self.r_: int = 0
self.alpha_: int = 0
self.dropout_: float = 0.0
self.scaling_: float = 0.0

def set_parameter(self, r: int, alpha: int, dropout: float):
self.r_ = r
self.alpha_ = alpha
self.dropout_ = dropout
self.scaling_ = alpha / r

def forward(self, data: torch.Tensor) -> torch.Tensor:
data_ = F.dropout(data, self.dropout_)
data_ @= self.lora_a_.transpose(0, 1)
data_ @= self.lora_b_.transpose(0, 1)
data_ *= self.scaling_
return data_


class Linear():
def __init__(self, weight: torch.Tensor):
row, col = weight.shape
@@ -80,31 +105,25 @@ def __init__(self, weight: torch.Tensor):
self.use_adapter_: bool = False
# adapter list
self.adapter_names_: Set[str] = set()
# lora weight
self.lora_a_: Dict[str, torch.Tensor] = {} # r * dim
self.lora_b_: Dict[str, torch.Tensor] = {} # dim * r
# common paramas
self.lora_dropout_: Dict[str, float] = {}
self.r_: Dict[str, int] = {}
self.lora_alpha_: Dict[str, int] = {}
self.scaling_: Dict[str, float] = {}

def update_layer(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float):
if len(self.adapter_names_) <= 0:
self.loras_: Dict[str, Lora] = {}

def set_lora_layer_parameter(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float):
if len(self.adapter_names_) <= 0 or not self.use_adapter_:
return

self.r_[adapter_name] = r
self.lora_alpha_[adapter_name] = lora_alpha
self.lora_dropout_[adapter_name] = lora_dropout
self.scaling_[adapter_name] = lora_alpha / r
self.loras_[adapter_name].set_parameter(r, lora_alpha, lora_dropout)

def set_lora_layer_weight(self, adapter_name: str, lora_name: str, weight: torch.Tensor):
if adapter_name not in self.loras_:
self.loras_[adapter_name] = Lora(adapter_name)

def update_lora_weight(self, adapter_name: str, lora_name: str, weight: torch.Tensor):
if lora_name == "lora_A":
self.lora_a_[adapter_name] = weight
self.loras_[adapter_name].lora_a_ = weight
elif lora_name == "lora_B":
self.lora_b_[adapter_name] = weight
self.loras_[adapter_name].lora_b_ = weight
else:
raise (f"No lora_name {lora_name}")

self.adapter_names_.add(adapter_name)

def forward(self, data: torch.Tensor, input_args: MultiLoraBatchData) -> torch.Tensor:
@@ -123,12 +142,8 @@ def forward(self, data: torch.Tensor, input_args: MultiLoraBatchData) -> torch.T
if adapter_name == "":
continue

data_ = F.dropout(data[start_idx: end_idx],
self.lora_dropout_[adapter_name])
data_ @= self.lora_a_[adapter_name].transpose(0, 1)
data_ @= self.lora_b_[adapter_name].transpose(0, 1)
data_ *= self.scaling_[adapter_name]
result[start_idx: end_idx] += data_
result[start_idx: end_idx] += self.loras_[
adapter_name].forward(data[start_idx:end_idx])

return result

@@ -154,14 +169,12 @@ def __init__(self, layer_id: int, args: LlamaModelArgs):
self.n_heads_ = args.n_heads_
self.head_dim_ = args.dim_ // args.n_heads_

def update_lora_configure(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float):
self.wk_.update_layer(adapter_name, r, lora_alpha, lora_dropout)
self.wq_.update_layer(adapter_name, r, lora_alpha, lora_dropout)
self.wv_.update_layer(adapter_name, r, lora_alpha, lora_dropout)
self.wo_.update_layer(adapter_name, r, lora_alpha, lora_dropout)
self.w1_.update_layer(adapter_name, r, lora_alpha, lora_dropout)
self.w2_.update_layer(adapter_name, r, lora_alpha, lora_dropout)
self.w3_.update_layer(adapter_name, r, lora_alpha, lora_dropout)
def set_lora_parameter(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float):
linear_layer_list = [self.wk_, self.wq_, self.wv_,
self.wo_, self.w1_, self.w2_, self.w3_]
for linear_layer in linear_layer_list:
linear_layer.set_lora_layer_parameter(
adapter_name, r, lora_alpha, lora_dropout)

# @torch.compile
def forward(self, data: torch.Tensor, mask: torch.Tensor, rope_angle: Tuple[torch.Tensor, torch.Tensor], input_args: MultiLoraBatchData):
@@ -229,8 +242,8 @@ def __init__(self, args: LlamaModelArgs):
self.dim_ = args.dim_

def update_lora_configure(self, adapter_name: str, r: int, lora_alpha: int, lora_dropout: float):
for layer in self.layers_:
layer.update_lora_configure(
for transformer_layer in self.layers_:
transformer_layer.set_lora_parameter(
adapter_name, r, lora_alpha, lora_dropout)

def forward(self, input: MultiLoraBatchData):
@@ -256,19 +269,16 @@ def forward_for_checkpoint(*inputs):

def get_train_paramas(self, config: Dict[str, str]) -> List[int]:
train_paramas = []
for layer in self.layers_:
for transformer_layer in self.layers_:
for lora_config in config["lora"]:
adapter_name = lora_config["name"]
if adapter_name in layer.wq_.lora_a_:
train_paramas.append(layer.wq_.lora_a_[adapter_name])
train_paramas.append(layer.wq_.lora_b_[adapter_name])
if adapter_name in layer.wk_.lora_a_:
train_paramas.append(layer.wk_.lora_a_[adapter_name])
train_paramas.append(layer.wk_.lora_b_[adapter_name])
if adapter_name in layer.wv_.lora_a_:
train_paramas.append(layer.wv_.lora_a_[adapter_name])
train_paramas.append(layer.wv_.lora_b_[adapter_name])
if adapter_name in layer.wo_.lora_a_:
train_paramas.append(layer.wo_.lora_a_[adapter_name])
train_paramas.append(layer.wo_.lora_b_[adapter_name])
lora_layer_list = [transformer_layer.wq_.loras_, transformer_layer.wk_.loras_,
transformer_layer.wv_.loras_, transformer_layer.wo_.loras_,
transformer_layer.w1_.loras_, transformer_layer.w2_.loras_,
transformer_layer.w3_.loras_]

for lora_layer in lora_layer_list:
if adapter_name in lora_layer:
train_paramas.append(lora_layer[adapter_name].lora_a_)
train_paramas.append(lora_layer[adapter_name].lora_b_)
return train_paramas
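The new Lora class gathers the adapter math that was previously inlined in Linear.forward: dropout, the low-rank pair, and the alpha/r scaling. A standalone sketch of that computation in plain PyTorch, with made-up shapes and hyperparameters rather than the repo's API, plus the fp32 AdamW step the commit message is about:

```python
import torch
import torch.nn.functional as F

# Standalone sketch of the Lora.forward math above (made-up dims/hyperparameters):
#   delta = (alpha / r) * dropout(x) @ lora_a.T @ lora_b.T
# where lora_a is (r, dim) and lora_b is (dim, r), matching the comments in Linear.
dim, r, alpha, dropout_p = 4096, 8, 16, 0.05

x = torch.randn(2, 16, dim)                              # (batch, seq, dim), fp32
lora_a = (torch.randn(r, dim) * 1e-3).requires_grad_()   # fp32 by default
lora_b = (torch.randn(dim, r) * 1e-3).requires_grad_()

delta = F.dropout(x, dropout_p) @ lora_a.T @ lora_b.T * (alpha / r)
result = x + delta   # here x stands in for the frozen projection data @ weight_.T

# get_train_paramas collects exactly these lora_a_/lora_b_ tensors; presumably
# they feed an AdamW optimizer roughly like this (the learning rate is made up).
optimizer = torch.optim.AdamW([lora_a, lora_b], lr=3e-4)
result.sum().backward()   # dummy loss, just to exercise one optimizer step
optimizer.step()          # safe now that eps and the moments stay in fp32
```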
136 changes: 30 additions & 106 deletions aspen/modelloader.py
@@ -42,7 +42,7 @@ def load_llama_7b_weight(model: LlamaModel, llama_model_path: str, device: str):
elif "norm.weight" in layer_name:
model.norm_ = RMSNorm(w, model.norm_eps_)
elif "output.weight" in layer_name:
model.output_ = w
model.output_ = w.to(torch.float32)
else:
print(f"Not use layer {layer_name}.", file=sys.stderr)

@@ -86,123 +86,47 @@ def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str):
elif "norm.weight" in layer_name:
model.norm_ = RMSNorm(w, model.norm_eps_)
elif "lm_head.weight" in layer_name:
model.output_ = w
model.output_ = w.to(torch.float32)
else:
print(f"Not use layer {layer_name}.", file=sys.stderr)


def load_alpaca_lora_7b_weight(model: LlamaModel, lora_model_path: str, adapter_name: str, device: str):
lora_weight = torch.load(
lora_model_path, map_location=torch.device(device))
for layer_name in lora_weight:
w: torch.Tensor = lora_weight[layer_name].to(torch.float16)
w.requires_grad_(True)

layer_name = layer_name[len("base_model.model.model.layers."):]
layer_id = int(layer_name[:layer_name.find(".")])
lora_name = ""
if "lora_A" in layer_name:
lora_name = "lora_A"
elif "lora_B" in layer_name:
lora_name = "lora_B"

if "q_proj" in layer_name:
model.layers_[layer_id].wq_.update_lora_weight(
adapter_name, lora_name, w)
model.layers_[layer_id].wq_.use_adapter_ = True
elif "k_proj" in layer_name:
model.layers_[layer_id].wk_.update_lora_weight(
adapter_name, lora_name, w)
model.layers_[layer_id].wk_.use_adapter_ = True
elif "v_proj" in layer_name:
model.layers_[layer_id].wv_.update_lora_weight(
adapter_name, lora_name, w)
model.layers_[layer_id].wv_.use_adapter_ = True
elif "o_proj" in layer_name:
model.layers_[layer_id].wo_.update_lora_weight(
adapter_name, lora_name, w)
model.layers_[layer_id].wo_.use_adapter_ = True
else:
print(f"Not user layer {layer_name}")


def load_random_lora_7b_weight(model: LlamaModel, adapter_name: str, r: int, dim: int, target_module: str, device: str) -> None:
norm_mean = 0
norm_std = 1e-3
for layer in model.layers_:
if target_module["q_proj"] is True:
wq_lora_a_weight = torch.normal(
mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float16)
wq_lora_b_weight = torch.normal(
mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float16)
layer.wq_.update_lora_weight(
adapter_name, "lora_A", wq_lora_a_weight)
layer.wq_.update_lora_weight(
adapter_name, "lora_B", wq_lora_b_weight)
layer.wq_.use_adapter_ = True

if target_module["k_proj"] is True:
wk_lora_a_weight = torch.normal(
mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float16)
wk_lora_b_weight = torch.normal(
mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float16)
layer.wk_.update_lora_weight(
adapter_name, "lora_A", wk_lora_a_weight)
layer.wk_.update_lora_weight(
adapter_name, "lora_B", wk_lora_b_weight)
layer.wk_.use_adapter_ = True

if target_module["v_proj"] is True:
wv_lora_a_weight = torch.normal(
mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float16)
wv_lora_b_weight = torch.normal(
mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float16)
layer.wv_.update_lora_weight(
adapter_name, "lora_A", wv_lora_a_weight)
layer.wv_.update_lora_weight(
adapter_name, "lora_B", wv_lora_b_weight)
layer.wv_.use_adapter_ = True

if target_module["o_proj"] is True:
wo_lora_a_weight = torch.normal(
mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float16)
wo_lora_b_weight = torch.normal(
mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float16)
layer.wo_.update_lora_weight(
adapter_name, "lora_A", wo_lora_a_weight)
layer.wo_.update_lora_weight(
adapter_name, "lora_B", wo_lora_b_weight)
layer.wo_.use_adapter_ = True
target_module_name_list = ["q_proj", "k_proj", "v_proj", "o_proj", "w1_proj", "w2_proj", "w3_proj"]
for transformer_layer in model.layers_:
target_layer_list = [transformer_layer.wq_, transformer_layer.wk_,
transformer_layer.wv_, transformer_layer.wo_,
transformer_layer.w1_, transformer_layer.w2_,
transformer_layer.w3_]
for idx, module_name in enumerate(target_module_name_list):
if module_name in target_module and target_module[module_name]:
lora_a_weight = torch.normal(
mean=norm_mean, std=norm_std, size=(r, dim), device=device, requires_grad=True, dtype=torch.float32)
lora_b_weight = torch.normal(
mean=norm_mean, std=norm_std, size=(dim, r), device=device, requires_grad=True, dtype=torch.float32)
target_layer_list[idx].set_lora_layer_weight(
adapter_name, "lora_A", lora_a_weight)
target_layer_list[idx].set_lora_layer_weight(
adapter_name, "lora_B", lora_b_weight)


def save_lora_model(model: LlamaModel, path: str, lora_name: str):
lora_weight_dict = {}
for idx, layer in enumerate(model.layers_):
for idx, transformer_layer in enumerate(model.layers_):
layer_prefix_name = "base_model.model.model.layers." + \
str(idx) + "." + "self_attn."
if lora_name in layer.wq_.lora_a_:
lora_weight_dict[layer_prefix_name +
"q_proj.lora_A.weight"] = layer.wq_.lora_a_[lora_name]
if lora_name in layer.wq_.lora_b_:
lora_weight_dict[layer_prefix_name +
"q_proj.lora_B.weight"] = layer.wq_.lora_b_[lora_name]
if lora_name in layer.wk_.lora_a_:
lora_weight_dict[layer_prefix_name +
"k_proj.lora_A.weigth"] = layer.wk_.lora_a_[lora_name]
if lora_name in layer.wk_.lora_b_:
lora_weight_dict[layer_prefix_name +
"k_proj.lora_B.weight"] = layer.wk_.lora_b_[lora_name]
if lora_name in layer.wv_.lora_a_:
lora_weight_dict[layer_prefix_name +
"v_proj.lora_A.weight"] = layer.wv_.lora_a_[lora_name]
if lora_name in layer.wv_.lora_b_:
lora_weight_dict[layer_prefix_name +
"v_proj.lora_B.weight"] = layer.wv_.lora_b_[lora_name]
if lora_name in layer.wo_.lora_a_:
lora_weight_dict[layer_prefix_name +
"o_proj.lora_A.weight"] = layer.wo_.lora_a_[lora_name]
if lora_name in layer.wo_.lora_b_:
lora_weight_dict[layer_prefix_name +
"o_proj.lora_B.weight"] = layer.wo_.lora_b_[lora_name]
lora_layer_list = [transformer_layer.wq_, transformer_layer.wk_,
transformer_layer.wv_, transformer_layer.wo_,
transformer_layer.w1_, transformer_layer.w2_,
transformer_layer.w3_]
lora_layer_name_list = ["q_proj", "k_proj", "v_proj", "o_proj", "w1_proj", "w2_proj", "w3_proj"]
for idx, lora_layer in enumerate(lora_layer_list):
if lora_name in lora_layer.loras_:
lora_weight_dict[layer_prefix_name +
f"{lora_layer_name_list[idx]}.lora_A.weight"] = lora_layer.loras_[lora_name].lora_a_
lora_weight_dict[layer_prefix_name +
f"{lora_layer_name_list[idx]}.lora_B.weight"] = lora_layer.loras_[lora_name].lora_b_

torch.save(lora_weight_dict, path)
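save_lora_model now walks one list of wrapped Linear layers instead of repeating the q/k/v/o blocks, and every tensor is stored under a base_model.model.model.layers.<idx>.self_attn.<module>.lora_A|lora_B.weight key. A rough sketch of reading such a file back; the filename is illustrative, not one the repo uses:

```python
import torch

# Rough sketch: inspect a checkpoint written by save_lora_model above.
# "lora_0.pt" is an illustrative path, not a path defined by the repo.
state = torch.load("lora_0.pt", map_location="cpu")

for key, tensor in state.items():
    # e.g. base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
    layer_id = int(key.split("layers.")[1].split(".")[0])
    module_and_role = key.rsplit("self_attn.", 1)[1]   # e.g. "q_proj.lora_A.weight"
    print(layer_id, module_and_role, tuple(tensor.shape), tensor.dtype)
```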
17 changes: 9 additions & 8 deletions config/lora.json
@@ -1,10 +1,11 @@
{
"base_model": "/yezhengmao/modules/llama-7b/7B/consolidated.00.pth",
"token_model": "/yezhengmao/modules/llama-7b/tokenizer.model",
"base_model": "",
"token_model": "",
"cutoff_len": 512,
"group_by_length": false,
"expand_right": true,
"device": "cuda:1",
"save_step": 200,
"lora": [
{
"name": "lora_0",
@@ -20,9 +21,9 @@
"v_proj": true,
"o_proj": true
},
"data": "data/train_lora_a.json",
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}\n\n",
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n{output}\n\n"
"data": "",
"prompt_input": "",
"prompt_no_input": ""
},
{
"name": "lora_1",
@@ -38,9 +39,9 @@
"v_proj": true,
"o_proj": true
},
"data": "data/train_lora_b.json",
"prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}\n\n",
"prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n{output}\n\n"
"data": "",
"prompt_input": "",
"prompt_no_input": ""
}
]
}
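The model paths, dataset paths, and prompt templates are now blanked out, so they have to be filled in locally before training. A minimal sketch of consuming this config; only fields visible in the diff are accessed, and target_modules is an assumed name for the key holding the q_proj/k_proj/v_proj/o_proj flags:

```python
import json

# Minimal sketch: read config/lora.json and walk the per-adapter entries.
# "target_modules" is an assumed key name for the projection-flag block above;
# every other field name appears verbatim in the diff.
with open("config/lora.json") as f:
    config = json.load(f)

print(config["base_model"], config["device"], config["cutoff_len"], config["save_step"])

for lora_config in config["lora"]:
    name = lora_config["name"]                       # "lora_0", "lora_1"
    targets = lora_config.get("target_modules", {})  # q_proj/k_proj/v_proj/o_proj flags
    enabled = [module for module, on in targets.items() if on]
    print(name, enabled, lora_config["data"] or "<data path not set>")
```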