diff --git a/README.md b/README.md
index c8511ae8..9aebc00c 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,7 @@ Support
 ```bash
 # Optional but recommended
 conda create -n aspen_env python=3.6
+conda activate aspen_env
 # Install requirements
 pip install -r requirements.txt
 ```
diff --git a/aspen/model.py b/aspen/model.py
index 051c83de..6c3dff03 100644
--- a/aspen/model.py
+++ b/aspen/model.py
@@ -7,7 +7,6 @@
 import xformers.ops
 import xformers.ops.fmha.attn_bias
 from typing import List, Dict, Set, Tuple
-from bitsandbytes.nn import Linear8bitLt, Int8Params
 
 
 def precompute_rope_angle(dim: int, seq_len: int, device: str, theta: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -97,12 +96,19 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
 
 
 class Linear():
-    def __init__(self, weight: torch.Tensor):
+    def __init__(self, weight: torch.Tensor, load_in_8bit: bool = True, device: str = None):
+        if device is None:
+            device = weight.device
         row, col = weight.shape
-        self.weight_ = Linear8bitLt(
-            input_features=col, output_features=row, bias=False, has_fp16_weights=False)
-        self.weight_.weight = Int8Params(
-            weight.data, requires_grad=False).cuda(weight.device)
+        if load_in_8bit:
+            from bitsandbytes.nn import Linear8bitLt, Int8Params
+            self.weight_ = Linear8bitLt(
+                input_features=col, output_features=row, bias=False, has_fp16_weights=False, device=device)
+            self.weight_.weight = Int8Params(
+                weight.data, requires_grad=False).cuda(device)
+        else:
+            self.weight_ = torch.nn.Linear(in_features=col, out_features=row, bias=False, device=device)
+            self.weight_.weight = torch.nn.Parameter(weight.data.to(device), requires_grad=False)
         self.use_adapter_: bool = False
         # adapter list
         self.adapter_names_: Set[str] = set()
diff --git a/aspen/modelloader.py b/aspen/modelloader.py
index 4477d42f..85d46fa4 100644
--- a/aspen/modelloader.py
+++ b/aspen/modelloader.py
@@ -51,9 +51,8 @@
             print(f"Not use layer {layer_name}.", file=sys.stderr)
 
 
-def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str):
-    weight = LlamaForCausalLM.from_pretrained(
-        llama_model_path, device_map=dev).state_dict()
+def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str, load_in_8bit: bool = False):
+    weight = LlamaForCausalLM.from_pretrained(llama_model_path).state_dict(keep_vars=True)
 
     for layer_name in weight:
         w: torch.Tensor = weight[layer_name]
@@ -63,34 +62,34 @@
             layer_name = layer_name[len("model.layers."):]
             layer_id = int(layer_name[:layer_name.find(".")])
             if "self_attn.q_proj" in layer_name:
-                model.layers_[layer_id].wq_ = Linear(w)
+                model.layers_[layer_id].wq_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.k_proj" in layer_name:
-                model.layers_[layer_id].wk_ = Linear(w)
+                model.layers_[layer_id].wk_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.v_proj" in layer_name:
-                model.layers_[layer_id].wv_ = Linear(w)
+                model.layers_[layer_id].wv_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.o_proj" in layer_name:
-                model.layers_[layer_id].wo_ = Linear(w)
+                model.layers_[layer_id].wo_ = Linear(w, load_in_8bit, dev)
             elif "mlp.gate_proj" in layer_name:
-                model.layers_[layer_id].w1_ = Linear(w)
+                model.layers_[layer_id].w1_ = Linear(w, load_in_8bit, dev)
             elif "mlp.down_proj" in layer_name:
-                model.layers_[layer_id].w2_ = Linear(w)
+                model.layers_[layer_id].w2_ = Linear(w, load_in_8bit, dev)
             elif "mlp.up_proj" in layer_name:
-                model.layers_[layer_id].w3_ = Linear(w)
+                model.layers_[layer_id].w3_ = Linear(w, load_in_8bit, dev)
             elif "input_layernorm" in layer_name:
                 model.layers_[layer_id].attention_norm_ = RMSNorm(
-                    w, model.norm_eps_)
+                    w.to(device=dev), model.norm_eps_)
             elif "post_attention_layernorm" in layer_name:
                 model.layers_[layer_id].ffn_norm_ = RMSNorm(
-                    w, model.norm_eps_)
+                    w.to(device=dev), model.norm_eps_)
             else:
                 print(
                     f"Not use layer model.layers.{layer_name}.", file=sys.stderr)
         elif "embed_tokens" in layer_name:
-            model.token_embedding_ = w
+            model.token_embedding_ = w.to(device=dev)
         elif "norm.weight" in layer_name:
-            model.norm_ = RMSNorm(w, model.norm_eps_)
+            model.norm_ = RMSNorm(w.to(device=dev), model.norm_eps_)
         elif "lm_head.weight" in layer_name:
-            model.output_ = w.to(torch.float32)
+            model.output_ = w.to(dtype=torch.float32, device=dev)
         else:
             print(f"Not use layer {layer_name}.", file=sys.stderr)
diff --git a/mlora.py b/mlora.py
index b877eee1..a6d71146 100644
--- a/mlora.py
+++ b/mlora.py
@@ -24,6 +24,7 @@
 parser = argparse.ArgumentParser(description='ASPEN main program')
 parser.add_argument('--model_name_or_path', type=str, help='Path to or name of base model')
+parser.add_argument('--load_in_8bit', action='store_true', help='Load the base model in 8-bit mode')
 parser.add_argument('--device', type=str, default='cuda:0', help='Specify which GPU to be used, default is cuda:0')
 parser.add_argument('--log', type=bool, default=True, help='Turn on or off log, default is true')
@@ -43,17 +44,23 @@ def log(msg: str):
     exit(-1)
 
 
+if args.model_name_or_path is None:
+    print('error: Argument --model_name_or_path is required.')
+    parser.print_help()
+    exit(-1)
+
+
 def prep_llm():
-    args = aspen.LlamaModelArgs()
+    llama_args = aspen.LlamaModelArgs()
     tokenizer = aspen.Tokenizer(args.model_name_or_path + os.sep + 'tokenizer.model')
     tokenizer.pad_id_ = 0
-    args.max_seq_len_ = 4096
-    args.device = args.device
-    args.vocab_size_ = tokenizer.n_words_
-    args.pad_id_ = tokenizer.pad_id_
-    args.n_heads_ = 32
-    model = aspen.LlamaModel(args)
-    aspen.load_llama_tf_weight(model, args.model_name_or_path, args.device)
+    llama_args.max_seq_len_ = 4096
+    llama_args.device = args.device
+    llama_args.vocab_size_ = tokenizer.n_words_
+    llama_args.pad_id_ = tokenizer.pad_id_
+    llama_args.n_heads_ = 32
+    model = aspen.LlamaModel(llama_args)
+    aspen.load_llama_tf_weight(model, args.model_name_or_path, args.device, args.load_in_8bit)
     return tokenizer, model
 
 
diff --git a/requirements.txt b/requirements.txt
index 625740b2..bee2672e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ transformers
 bitsandbytes
 sentencepiece
 scipy
+accelerate
\ No newline at end of file
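For reference, a minimal sketch of how the patched `Linear` wrapper in `aspen/model.py` is expected to behave after this change. This is an illustration, not part of the patch: it assumes the patched `aspen.model.Linear` is importable, bitsandbytes is installed, and a CUDA device is available; the weight shape below is made up.

```python
import torch
from aspen.model import Linear  # the class modified in this diff

# Illustrative fp16 weight standing in for one pretrained projection matrix.
w = torch.randn(32, 64, dtype=torch.float16, device="cuda:0")

# Default path (load_in_8bit=True): the weight is wrapped in bitsandbytes'
# Linear8bitLt via Int8Params and quantized to int8 when moved to the GPU.
proj_8bit = Linear(w)

# Opt-out path: a plain torch.nn.Linear holding the original fp16 weight.
proj_fp16 = Linear(w, load_in_8bit=False, device="cuda:0")
```

From the command line, the same switch is exposed by `mlora.py` as `--load_in_8bit`, e.g. `python mlora.py --model_name_or_path <model dir> --load_in_8bit` (the model path here is a placeholder).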