
Commit

Fix unreasonable memory allocation
mikecovlee committed Aug 30, 2023
1 parent aed5484 commit 05b5cf2
Showing 3 changed files with 20 additions and 18 deletions.
10 changes: 6 additions & 4 deletions aspen/model.py
@@ -96,16 +96,18 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
 
 
 class Linear():
-    def __init__(self, weight: torch.Tensor, load_in_8bit: bool = True):
+    def __init__(self, weight: torch.Tensor, load_in_8bit: bool = True, device: str = None):
+        if device == None:
+            device = weight.device
         row, col = weight.shape
         if load_in_8bit:
             from bitsandbytes.nn import Linear8bitLt, Int8Params
             self.weight_ = Linear8bitLt(
-                input_features=col, output_features=row, bias=False, has_fp16_weights=False)
+                input_features=col, output_features=row, bias=False, has_fp16_weights=False, device=device)
             self.weight_.weight = Int8Params(
-                weight.data, requires_grad=False).cuda(weight.device)
+                weight.data, requires_grad=False).cuda(device)
         else:
-            self.weight_ = torch.nn.Linear(in_features=col, out_features=row, bias=False)
+            self.weight_ = torch.nn.Linear(in_features=col, out_features=row, bias=False, device=device)
         self.use_adapter_: bool = False
         # adapter list
         self.adapter_names_: Set[str] = set()
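
The change above makes the target device explicit when the Linear wrapper is constructed, so bitsandbytes and torch.nn.Linear allocate directly on that device instead of allocating on a default device and copying afterwards. A minimal usage sketch, not part of the commit — the weight shape, device string, and call sites are placeholders, and it assumes a CUDA device (and bitsandbytes for the 8-bit path) is available:

import torch
from aspen.model import Linear  # package path as shown in this diff

w = torch.empty(4096, 4096, dtype=torch.float16)  # placeholder weight; real weights come from the checkpoint
fp16_layer = Linear(w, load_in_8bit=False, device="cuda:0")  # nn.Linear is created directly on cuda:0
int8_layer = Linear(w, load_in_8bit=True, device="cuda:0")   # Int8Params are quantized and moved to cuda:0 once
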
27 changes: 13 additions & 14 deletions aspen/modelloader.py
@@ -52,8 +52,7 @@ def load_llama_7b_weight(model: LlamaModel, llama_model_path: str, device: str):
 
 
 def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str, load_in_8bit: bool = False):
-    weight = LlamaForCausalLM.from_pretrained(
-        llama_model_path, device_map=dev, load_in_8bit=load_in_8bit).state_dict()
+    weight = LlamaForCausalLM.from_pretrained(llama_model_path).state_dict(keep_vars=True)
 
     for layer_name in weight:
         w: torch.Tensor = weight[layer_name]
@@ -63,34 +62,34 @@ def load_llama_tf_weight(model: LlamaModel, llama_model_path: str, dev: str, load_in_8bit: bool = False):
             layer_name = layer_name[len("model.layers."):]
             layer_id = int(layer_name[:layer_name.find(".")])
             if "self_attn.q_proj" in layer_name:
-                model.layers_[layer_id].wq_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].wq_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.k_proj" in layer_name:
-                model.layers_[layer_id].wk_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].wk_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.v_proj" in layer_name:
-                model.layers_[layer_id].wv_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].wv_ = Linear(w, load_in_8bit, dev)
             elif "self_attn.o_proj" in layer_name:
-                model.layers_[layer_id].wo_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].wo_ = Linear(w, load_in_8bit, dev)
             elif "mlp.gate_proj" in layer_name:
-                model.layers_[layer_id].w1_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].w1_ = Linear(w, load_in_8bit, dev)
             elif "mlp.down_proj" in layer_name:
-                model.layers_[layer_id].w2_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].w2_ = Linear(w, load_in_8bit, dev)
             elif "mlp.up_proj" in layer_name:
-                model.layers_[layer_id].w3_ = Linear(w, load_in_8bit)
+                model.layers_[layer_id].w3_ = Linear(w, load_in_8bit, dev)
             elif "input_layernorm" in layer_name:
                 model.layers_[layer_id].attention_norm_ = RMSNorm(
-                    w, model.norm_eps_)
+                    w.to(device=dev), model.norm_eps_)
             elif "post_attention_layernorm" in layer_name:
                 model.layers_[layer_id].ffn_norm_ = RMSNorm(
-                    w, model.norm_eps_)
+                    w.to(device=dev), model.norm_eps_)
             else:
                 print(
                     f"Not use layer model.layers.{layer_name}.", file=sys.stderr)
         elif "embed_tokens" in layer_name:
-            model.token_embedding_ = w
+            model.token_embedding_ = w.to(device=dev)
         elif "norm.weight" in layer_name:
-            model.norm_ = RMSNorm(w, model.norm_eps_)
+            model.norm_ = RMSNorm(w.to(device=dev), model.norm_eps_)
         elif "lm_head.weight" in layer_name:
-            model.output_ = w.to(torch.float32)
+            model.output_ = w.to(dtype=torch.float32, device=dev)
         else:
             print(f"Not use layer {layer_name}.", file=sys.stderr)
 
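
With this change the loader reads the checkpoint into host memory once (from_pretrained without device_map or load_in_8bit) and moves each tensor to dev only as its wrapper is built, instead of materializing a full copy of the model on the target device before the per-layer Linear wrappers allocate a second one. A condensed sketch of that pattern — the checkpoint path and device string are placeholders, and the per-layer dispatch is elided:

import torch
from transformers import LlamaForCausalLM

dev = "cuda:0"              # placeholder target device
path = "/path/to/llama-7b"  # placeholder local checkpoint path

# Weights stay on the CPU here; nothing is allocated on `dev` yet.
weight = LlamaForCausalLM.from_pretrained(path).state_dict(keep_vars=True)

for layer_name, w in weight.items():
    w = w.to(device=dev)    # each tensor is copied to the device exactly once
    # ... hand w to the matching aspen layer wrapper, as in load_llama_tf_weight above ...
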
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,3 +5,4 @@ transformers
 bitsandbytes
 sentencepiece
 scipy
+accelerate
