Add 8bit loading mode for transformers models #15

Merged (13 commits) on Aug 30, 2023
README.md (1 addition, 0 deletions)
@@ -21,6 +21,7 @@ Support
 ```bash
 # Optional but recommended
 conda create -n aspen_env python=3.6
+conda activate aspen_env
 # Install requirements
 pip install -r requirements.txt
 ```
aspen/model.py (11 additions, 6 deletions)
@@ -7,7 +7,6 @@
 import xformers.ops
 import xformers.ops.fmha.attn_bias
 from typing import List, Dict, Set, Tuple
-from bitsandbytes.nn import Linear8bitLt, Int8Params


 def precompute_rope_angle(dim: int, seq_len: int, device: str, theta: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -97,12 +96,18 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:


 class Linear():
-    def __init__(self, weight: torch.Tensor):
+    def __init__(self, weight: torch.Tensor, load_in_8bit: bool = True, device: str = None):
+        if device is None:
+            device = weight.device
         row, col = weight.shape
-        self.weight_ = Linear8bitLt(
-            input_features=col, output_features=row, bias=False, has_fp16_weights=False)
-        self.weight_.weight = Int8Params(
-            weight.data, requires_grad=False).cuda(weight.device)
+        if load_in_8bit:
+            from bitsandbytes.nn import Linear8bitLt, Int8Params
+            self.weight_ = Linear8bitLt(
+                input_features=col, output_features=row, bias=False, has_fp16_weights=False, device=device)
+            self.weight_.weight = Int8Params(
+                weight.data, requires_grad=False).cuda(device)
+        else:
+            self.weight_ = torch.nn.Linear(in_features=col, out_features=row, bias=False, device=device)
         self.use_adapter_: bool = False
         # adapter list
         self.adapter_names_: Set[str] = set()
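To make the 8-bit branch above concrete, here is a minimal, self-contained sketch of the same bitsandbytes pattern the wrapper uses; the weight shape and the CUDA device string are illustrative assumptions, not values from the PR:

```python
import torch
from bitsandbytes.nn import Linear8bitLt, Int8Params

# Illustrative fp16 weight, stored as (out_features, in_features) like nn.Linear.
weight = torch.randn(4096, 4096, dtype=torch.float16)

layer = Linear8bitLt(input_features=4096, output_features=4096,
                     bias=False, has_fp16_weights=False)
# Int8Params quantizes the fp16 weight to int8 when it is moved onto the GPU,
# which is why the wrapper assigns it and then calls .cuda(device).
layer.weight = Int8Params(weight.data, requires_grad=False).cuda("cuda:0")

x = torch.randn(2, 4096, dtype=torch.float16, device="cuda:0")
y = layer(x)       # int8 matmul under the hood; activations stay fp16
print(y.shape)     # torch.Size([2, 4096])
```

Either way the wrapper ends up holding a callable module in `self.weight_` (Linear8bitLt or plain torch.nn.Linear), so code that attaches adapters later sees the same interface.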
aspen/modelloader.py (14 additions, 15 deletions)
@@ -51,9 +51,8 @@ def load_llama_7b_weight(model: LlamaModel, llama_model_path: str, device: str):
             print(f"Not use layer {layer_name}.", file=sys.stderr)


-def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str):
-    weight = LlamaForCausalLM.from_pretrained(
-        llama_model_path, device_map=dev).state_dict()
+def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str, load_in_8bit: bool = False):
+    weight = LlamaForCausalLM.from_pretrained(llama_model_path).state_dict(keep_vars=True)

     for layer_name in weight:
         w: torch.Tensor = weight[layer_name]
@@ -63,34 +62,34 @@ def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str):
             layer_name = layer_name[len("model.layers."):]
             layer_id = int(layer_name[:layer_name.find(".")])
             if "self_attn.q_proj" in layer_name:
-                model.layers_[layer_id].wq_ = Linear(w)
+                model.layers_[layer_id].wq_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.k_proj" in layer_name:
-                model.layers_[layer_id].wk_ = Linear(w)
+                model.layers_[layer_id].wk_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.v_proj" in layer_name:
-                model.layers_[layer_id].wv_ = Linear(w)
+                model.layers_[layer_id].wv_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.o_proj" in layer_name:
-                model.layers_[layer_id].wo_ = Linear(w)
+                model.layers_[layer_id].wo_ = Linear(w, load_in_8bit, dev)
             elif "mlp.gate_proj" in layer_name:
-                model.layers_[layer_id].w1_ = Linear(w)
+                model.layers_[layer_id].w1_ = Linear(w, load_in_8bit, dev)
             elif "mlp.down_proj" in layer_name:
-                model.layers_[layer_id].w2_ = Linear(w)
+                model.layers_[layer_id].w2_ = Linear(w, load_in_8bit, dev)
             elif "mlp.up_proj" in layer_name:
-                model.layers_[layer_id].w3_ = Linear(w)
+                model.layers_[layer_id].w3_ = Linear(w, load_in_8bit, dev)
             elif "input_layernorm" in layer_name:
                 model.layers_[layer_id].attention_norm_ = RMSNorm(
-                    w, model.norm_eps_)
+                    w.to(device=dev), model.norm_eps_)
             elif "post_attention_layernorm" in layer_name:
                 model.layers_[layer_id].ffn_norm_ = RMSNorm(
-                    w, model.norm_eps_)
+                    w.to(device=dev), model.norm_eps_)
             else:
                 print(
                     f"Not use layer model.layers.{layer_name}.", file=sys.stderr)
         elif "embed_tokens" in layer_name:
-            model.token_embedding_ = w
+            model.token_embedding_ = w.to(device=dev)
         elif "norm.weight" in layer_name:
-            model.norm_ = RMSNorm(w, model.norm_eps_)
+            model.norm_ = RMSNorm(w.to(device=dev), model.norm_eps_)
         elif "lm_head.weight" in layer_name:
-            model.output_ = w.to(torch.float32)
+            model.output_ = w.to(dtype=torch.float32, device=dev)
         else:
             print(f"Not use layer {layer_name}.", file=sys.stderr)

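As rough intuition for what routing every projection weight through the 8-bit Linear wrapper saves at load time, a back-of-the-envelope sketch; the layer shape is illustrative, and the small per-row quantization statistics kept by bitsandbytes are ignored:

```python
# Approximate weight memory for a single 4096x4096 projection (LLaMA-7B-sized).
rows, cols = 4096, 4096
fp16_bytes = rows * cols * 2      # float16: 2 bytes per element
int8_bytes = rows * cols * 1      # int8: 1 byte per element
print(f"fp16: {fp16_bytes / 2**20:.0f} MiB, int8: {int8_bytes / 2**20:.0f} MiB")
# fp16: 32 MiB, int8: 16 MiB -> roughly half the weight memory per linear layer
```

The norms, embeddings, and lm_head, which the loader moves to `dev` unquantized, are unaffected by the flag.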
mlora.py (15 additions, 8 deletions)
@@ -24,6 +24,7 @@

 parser = argparse.ArgumentParser(description='ASPEN main program')
 parser.add_argument('--model_name_or_path', type=str, help='Path to or name of base model')
+parser.add_argument('--load_in_8bit', type=bool, default=False, help='Load model in 8bit mode')
 parser.add_argument('--device', type=str, default='cuda:0', help='Specify which GPU to be used, default is cuda:0')
 parser.add_argument('--log', type=bool, default=True, help='Turn on or off log, default is true')

@@ -43,17 +44,23 @@ def log(msg: str):
     exit(-1)


+if args.model_name_or_path is None:
+    print('error: Argument --model_name_or_path are required.')
+    parser.print_help()
+    exit(-1)
+
+
 def prep_llm():
-    args = aspen.LlamaModelArgs()
+    llama_args = aspen.LlamaModelArgs()
     tokenizer = aspen.Tokenizer(args.model_name_or_path + os.sep + 'tokenizer.model')
     tokenizer.pad_id_ = 0
-    args.max_seq_len_ = 4096
-    args.device = args.device
-    args.vocab_size_ = tokenizer.n_words_
-    args.pad_id_ = tokenizer.pad_id_
-    args.n_heads_ = 32
-    model = aspen.LlamaModel(args)
-    aspen.load_llama_tf_weight(model, args.model_name_or_path, args.device)
+    llama_args.max_seq_len_ = 4096
+    llama_args.device = args.device
+    llama_args.vocab_size_ = tokenizer.n_words_
+    llama_args.pad_id_ = tokenizer.pad_id_
+    llama_args.n_heads_ = 32
+    model = aspen.LlamaModel(llama_args)
+    aspen.load_llama_tf_weight(model, args.model_name_or_path, args.device, args.load_in_8bit)
     return tokenizer, model

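One behavior of the new flag worth keeping in mind: argparse applies `type=bool` by calling `bool()` on the raw string, so any non-empty value enables 8-bit loading. A quick standalone sketch of how it parses (the flag name comes from the diff; this snippet is not part of the PR):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--load_in_8bit', type=bool, default=False)

print(parser.parse_args([]).load_in_8bit)                            # False (default)
print(parser.parse_args(['--load_in_8bit', 'True']).load_in_8bit)    # True
print(parser.parse_args(['--load_in_8bit', 'False']).load_in_8bit)   # True, because bool('False') is True
```

In practice, passing `--load_in_8bit True` (or any non-empty string) turns the mode on, and omitting the flag leaves it off.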
requirements.txt (1 addition, 0 deletions)
@@ -5,3 +5,4 @@ transformers
 bitsandbytes
 sentencepiece
 scipy
+accelerate