Commit

Merge changes from @yezhengmao
mikecovlee committed Aug 29, 2023
1 parent 4b41721 commit 5c60685
Showing 1 changed file with 9 additions and 12 deletions.
legacy.py: 21 changes (9 additions & 12 deletions)
@@ -1,9 +1,10 @@
-import json
-import torch
 from aspen import LlamaModel, Tokenizer, DataSet
 from aspen import LlamaModelArgs, MultiLoraBatchData
 from aspen import load_llama_7b_weight, load_random_lora_7b_weight
 from aspen import save_lora_model
 
+import json
+import torch
+import torch.optim
 
 with open('config/lora.json', 'r', encoding='utf8') as fp:
@@ -47,19 +48,15 @@ def init_lora_model(llama_model: LlamaModel):
 
 torch.cuda.empty_cache()
 
-# optim begin
-optimizer = torch.optim.SGD(
-    llama_model.get_train_paramas(config), lr=1e-3)
-# optim end
+optimizer = torch.optim.AdamW(llama_model.get_train_paramas(config))
 
-step = 0
-# torch.autograd.set_detect_anomaly(True)
+step_cnt = 0
 while not data_set.check_done():
     optimizer.zero_grad()
     loss_fn = torch.nn.CrossEntropyLoss()
     input: MultiLoraBatchData = data_set.get_batch_data()
 
-    step += 1
+    step_cnt += 1
 
     output = llama_model.forward(input)
     labels = torch.tensor(input.batch_tokens_,
@@ -84,11 +81,11 @@ def init_lora_model(llama_model: LlamaModel):
     total_loss.backward()
     optimizer.step()
 
-    if step % 200 == 0:
+    if step_cnt % config["save_step"] == 0:
         for lora_config in config["lora"]:
             save_lora_model(
-                llama_model, lora_config["output"] + f".chk{step}", lora_config["name"])
+                llama_model, lora_config["output"] + f".bin{step_cnt}", lora_config["name"])
 
 for lora_config in config["lora"]:
     save_lora_model(
-        llama_model, lora_config["output"], lora_config["name"])
\ No newline at end of file
+        llama_model, lora_config["output"], lora_config["name"])
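For context: after this change the checkpointing logic is driven entirely by config/lora.json, which must supply a save_step interval and a lora list whose entries carry output and name. Only those key names are implied by the diff; everything else in the sketch below is a hypothetical placeholder.

```python
# Minimal sketch of the config/lora.json shape implied by the new save logic.
# Only the keys "save_step", "lora", "output", and "name" come from the diff;
# all values here are hypothetical placeholders.
import json

example_config = {
    "save_step": 200,  # checkpoint interval in optimizer steps (the old hard-coded value was 200)
    "lora": [
        {
            "name": "lora_0",    # hypothetical adapter name
            "output": "lora_0",  # final save path; periodic saves append f".bin{step_cnt}", e.g. "lora_0.bin200"
        },
    ],
}

with open("config/lora.json", "w", encoding="utf8") as fp:
    json.dump(example_config, fp, indent=4)
```

One side note on the optimizer swap: dropping the explicit lr=1e-3 does not change the nominal learning rate, since torch.optim.AdamW also defaults to lr=1e-3 (with betas=(0.9, 0.999) and weight_decay=0.01), though AdamW's adaptive updates and decoupled weight decay will of course behave differently from plain SGD.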
