diff --git a/aspen/model.py b/aspen/model.py
index 00cb6523..3f5a75f5 100644
--- a/aspen/model.py
+++ b/aspen/model.py
@@ -96,16 +96,18 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
 
 
 class Linear():
-    def __init__(self, weight: torch.Tensor, load_in_8bit: bool = True):
+    def __init__(self, weight: torch.Tensor, load_in_8bit: bool = True, device: str = None):
+        if device == None:
+            device = weight.device
         row, col = weight.shape
         if load_in_8bit:
             from bitsandbytes.nn import Linear8bitLt, Int8Params
             self.weight_ = Linear8bitLt(
-                input_features=col, output_features=row, bias=False, has_fp16_weights=False)
+                input_features=col, output_features=row, bias=False, has_fp16_weights=False, device=device)
             self.weight_.weight = Int8Params(
-                weight.data, requires_grad=False).cuda(weight.device)
+                weight.data, requires_grad=False).cuda(device)
         else:
-            self.weight_ = torch.nn.Linear(in_features=col, out_features=row, bias=False)
+            self.weight_ = torch.nn.Linear(in_features=col, out_features=row, bias=False, device=device)
         self.use_adapter_: bool = False
         # adapter list
         self.adapter_names_: Set[str] = set()
diff --git a/aspen/modelloader.py b/aspen/modelloader.py
index 0d2e5ae4..85d46fa4 100644
--- a/aspen/modelloader.py
+++ b/aspen/modelloader.py
@@ -52,8 +52,7 @@ def load_llama_7b_weight(model: LlamaModel, llama_model_path: str, device: str):
 
 
 def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str, load_in_8bit: bool = False):
-    weight = LlamaForCausalLM.from_pretrained(
-        llama_model_path, device_map=dev, load_in_8bit=load_in_8bit).state_dict()
+    weight = LlamaForCausalLM.from_pretrained(llama_model_path).state_dict(keep_vars=True)
 
     for layer_name in weight:
         w: torch.Tensor = weight[layer_name]
@@ -63,34 +62,34 @@ def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str, loa
             layer_name = layer_name[len("model.layers."):]
             layer_id = int(layer_name[:layer_name.find(".")])
             if "self_attn.q_proj" in layer_name:
-                model.layers_[layer_id].wq_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].wq_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.k_proj" in layer_name:
-                model.layers_[layer_id].wk_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].wk_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.v_proj" in layer_name:
-                model.layers_[layer_id].wv_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].wv_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.o_proj" in layer_name:
-                model.layers_[layer_id].wo_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].wo_ = Linear(w, load_in_8bit, dev)
             elif "mlp.gate_proj" in layer_name:
-                model.layers_[layer_id].w1_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].w1_ = Linear(w, load_in_8bit, dev)
             elif "mlp.down_proj" in layer_name:
-                model.layers_[layer_id].w2_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].w2_ = Linear(w, load_in_8bit, dev)
             elif "mlp.up_proj" in layer_name:
-                model.layers_[layer_id].w3_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].w3_ = Linear(w, load_in_8bit, dev)
             elif "input_layernorm" in layer_name:
                 model.layers_[layer_id].attention_norm_ = RMSNorm(
-                    w, model.norm_eps_)
+                    w.to(device=dev), model.norm_eps_)
             elif "post_attention_layernorm" in layer_name:
                 model.layers_[layer_id].ffn_norm_ = RMSNorm(
-                    w, model.norm_eps_)
+                    w.to(device=dev), model.norm_eps_)
             else:
                 print(
                     f"Not use layer model.layers.{layer_name}.", file=sys.stderr)
         elif "embed_tokens" in layer_name:
-            model.token_embedding_ = w
+            model.token_embedding_ = w.to(device=dev)
         elif "norm.weight" in layer_name:
-            model.norm_ = RMSNorm(w, model.norm_eps_)
+            model.norm_ = RMSNorm(w.to(device=dev), model.norm_eps_)
         elif "lm_head.weight" in layer_name:
-            model.output_ = w.to(torch.float32)
+            model.output_ = w.to(dtype=torch.float32, device=dev)
         else:
             print(f"Not use layer {layer_name}.", file=sys.stderr)
 
diff --git a/requirements.txt b/requirements.txt
index 625740b2..bee2672e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ transformers
 bitsandbytes
 sentencepiece
 scipy
+accelerate
\ No newline at end of file
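
Usage note (not part of the patch): a minimal sketch of how the new device argument is intended to be used, assuming the Linear signature shown in the diff above and that the aspen package is importable from the repo root; the weight tensor is random stand-in data rather than a real LLaMA checkpoint.

    # Sketch only: construct the patched Linear on an explicit device.
    import torch
    from aspen.model import Linear

    weight = torch.randn(32, 64)  # stand-in for a loaded weight matrix (row=32, col=64)

    # Explicit device, non-quantized path (avoids the bitsandbytes/CUDA requirement).
    layer = Linear(weight, load_in_8bit=False, device="cpu")

    # Device omitted: the new fallback branch uses weight.device instead of assuming CUDA.
    layer_default = Linear(weight, load_in_8bit=False)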