Skip to content

Commit 196a018

Browse files
committed
NeurIPS camera ready
1 parent 50df2f2 commit 196a018

17 files changed

+2578
-0
lines changed

Diff for: Calculate_grad_anc.py

+217
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# %%
2+
import os
3+
import transformers
4+
from transformers import AutoTokenizer, LlamaTokenizer, DataCollatorForSeq2Seq
5+
from peft import PeftModel, LoraConfig, get_peft_model
6+
7+
import sys
8+
from datasets import load_dataset
9+
from time import time
10+
from tqdm import tqdm
11+
from collections import defaultdict
12+
import pandas as pd
13+
import pickle
14+
import torch
15+
import argparse
16+
from torch.utils.data import DataLoader
17+
import pdb
18+
import json
19+
import numpy as np
20+
from utils import *
21+
from datasets import Dataset
22+
from trak.projectors import BasicProjector, CudaProjector, ProjectionType
23+
# %%
# Alpaca-style prompt templates.
# "prompt_full"  — instruction AND response; used to build the training target.
# "prompt_input" — instruction-only prefix; its token span is label-masked with
#                  -100 in generate_and_tokenize_prompt so the loss only covers
#                  the response.
PROMPT_DICT = {
    "prompt_full": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{}\n\n### Response:\n{}"
    ),
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{}"
    )
}
36+
37+
def load_model(model_path, checkpoint_path):
    """Load a causal LM in fp16 and optionally attach a trainable LoRA adapter.

    Args:
        model_path: base model name or path for AutoModelForCausalLM.
        checkpoint_path: PEFT adapter checkpoint directory; the empty string
            means "no adapter" and the base model is returned unchanged.

    Returns:
        The (possibly PEFT-wrapped) model, device-mapped automatically.
    """
    print(f"load:{model_path},{checkpoint_path}")

    base_model = transformers.AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map='auto',
        torch_dtype=torch.float16,
    )

    if checkpoint_path != '':
        # Wrap with the LoRA adapter; is_trainable=True keeps grads flowing
        # into the adapter weights so backward() can populate them.
        peft_model = PeftModel.from_pretrained(base_model, checkpoint_path, is_trainable=True)
        peft_model.print_trainable_parameters()
        return peft_model
    return base_model
47+
48+
def tokenize(tokenizer, prompt, cutoff_len=1024, add_eos_token=True):
    """Tokenize *prompt* and build causal-LM labels.

    Args:
        tokenizer: HF-style tokenizer callable returning a dict with
            "input_ids" and "attention_mask" lists; also exposes
            ``eos_token_id``.
        prompt: text to tokenize.
        cutoff_len: truncation length passed to the tokenizer.
        add_eos_token: if True, append EOS when the sequence does not already
            end with it and there is room under ``cutoff_len``.

    Returns:
        The tokenizer result dict, with "labels" added as a copy of
        "input_ids" (prompt-part masking happens in
        generate_and_tokenize_prompt).
    """
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=cutoff_len,
        padding=False,
        return_tensors=None,
    )
    input_ids = result["input_ids"]
    # Fix: the original indexed input_ids[-1] unconditionally, which raised
    # IndexError when the tokenizer returned an empty sequence (e.g. an empty
    # prompt). Guard the empty case and still append EOS to it.
    if (
        add_eos_token
        and len(input_ids) < cutoff_len
        and (not input_ids or input_ids[-1] != tokenizer.eos_token_id)
    ):
        input_ids.append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)

    result["labels"] = input_ids.copy()

    return result
67+
68+
def generate_and_tokenize_prompt(tokenizer, full_prompt, user_prompt, cutoff_len=2048, add_eos_token=True):
    """Tokenize the full prompt and mask the user-prompt span out of labels.

    Labels covering the user-prompt prefix (everything before the response)
    are replaced with -100 so the loss only covers the response tokens.

    Returns:
        The tokenized full prompt with its "labels" list prefix-masked.
    """
    full = tokenize(tokenizer, full_prompt, cutoff_len, add_eos_token)
    prefix = tokenize(tokenizer, user_prompt, cutoff_len, add_eos_token=add_eos_token)

    prefix_len = len(prefix["input_ids"])
    if add_eos_token:
        # The prefix-only encoding also received an EOS; that position in the
        # full sequence belongs to the response, so keep its label.
        prefix_len -= 1

    full["labels"] = [-100] * prefix_len + full["labels"][prefix_len:]

    return full
84+
def generate_grad(model, output_notation):
    """Collect the current gradients of a selected parameter subset as one flat vector.

    ``loss.backward()`` must already have populated ``.grad`` on the model.

    Args:
        model: module providing ``named_parameters()``.
        output_notation: which parameters to take —
            'q' / 'v'                       -> only q_proj / v_proj params
            'A' / 'B'                       -> only lora_A / lora_B params
            'qv', 'noproj', 'noproj_adam'   -> all params with gradients
            anything else (e.g. 'layers.0') -> params whose name contains it

    Returns:
        1-D CPU tensor of the selected gradients concatenated in
        named_parameters() order.
    """
    def _collect(predicate):
        # One flat CPU vector over every parameter with a grad matching predicate.
        return torch.cat([
            p.grad.cpu().view(-1)
            for n, p in model.named_parameters()
            if p.grad is not None and predicate(n)
        ])

    if output_notation == 'q':
        return _collect(lambda n: 'q_proj' in n)
    if output_notation == 'v':
        return _collect(lambda n: 'v_proj' in n)
    if output_notation == 'A':
        return _collect(lambda n: 'lora_A' in n)
    if output_notation == 'B':
        return _collect(lambda n: 'lora_B' in n)
    # Bug fix: the original wrote `output_notation=='noproj' or 'noproj_adam'`,
    # which is always truthy ('noproj_adam' is a non-empty string), so every
    # notation that fell through (e.g. 'layers.31') silently took ALL grads
    # and the substring branch below was unreachable.
    if output_notation in ('qv', 'noproj', 'noproj_adam'):
        return _collect(lambda n: True)
    # Any other notation (e.g. 'layers.0', 'layers.31') selects by substring.
    return _collect(lambda n: output_notation in n)
103+
104+
def prepare_optimizer_state(optimizer_state, device):
    """Flatten the Adam first/second moments from a saved optimizer state.

    Args:
        optimizer_state: mapping param-key -> {"exp_avg": tensor,
            "exp_avg_sq": tensor} (the "state" section of optimizer.pt).
        device: target device for the concatenated tensors.

    Returns:
        (avg, avg_sq): 1-D tensors of all exp_avg / exp_avg_sq values,
        concatenated in the mapping's iteration order.
    """
    first_moments = []
    second_moments = []
    for key in optimizer_state.keys():
        entry = optimizer_state[key]
        first_moments.append(entry["exp_avg"].view(-1))
        second_moments.append(entry["exp_avg_sq"].view(-1))
    avg = torch.cat(first_moments).to(device)
    avg_sq = torch.cat(second_moments).to(device)
    return avg, avg_sq
111+
112+
def compute_gradient(model, train_dataset, tokenizer, max_token_length=2048, checkpoint_path=None):
    """Compute one (optionally projected) gradient vector per training sample.

    Depends on module-level globals: ``output_notation`` (gradient slice /
    projection mode), ``proj_d``, ``seed``, ``device``, and the lazily created
    global projector ``Project``. Requires CUDA (inputs are moved to 'cuda').

    Args:
        model: causal LM (optionally PEFT-wrapped); backward() runs per sample.
        train_dataset: iterable of samples with 'input' and 'output' keys.
        tokenizer: tokenizer used to build the prompts.
        max_token_length: truncation length for tokenization.
        checkpoint_path: checkpoint directory containing ``optimizer.pt``
            (the Adam state is loaded even when output_notation doesn't use it).

    Returns:
        dict mapping sample index -> gradient tensor (raw, Adam-preconditioned,
        or randomly projected, depending on ``output_notation``).
    """
    # Load the Adam moments (m, v) saved alongside the checkpoint.
    optimizer_path = os.path.join(checkpoint_path, "optimizer.pt")
    adam_optimizer_state = torch.load(optimizer_path, map_location="cpu")["state"]
    avg, avg_sq = prepare_optimizer_state(adam_optimizer_state, "cpu")
    print("m/v:" , avg,avg_sq)
    tr_grad_dict = {} # per-sample gradient
    for idx, sample in enumerate(tqdm(train_dataset)):
        model.eval()
        model.zero_grad() # zeroing out gradient
        # Full prompt = instruction + response; input prompt = instruction only.
        # generate_and_tokenize_prompt masks the instruction span to -100 so the
        # loss (and hence the gradient) covers only the response tokens.
        full_prompt = PROMPT_DICT['prompt_full'].format(sample['input'], sample['output'])
        input_prompt = PROMPT_DICT['prompt_input'].format(sample['input'])
        tokenized_input = generate_and_tokenize_prompt(tokenizer, full_prompt, input_prompt, max_token_length)
        input_ids = torch.tensor(tokenized_input['input_ids']).unsqueeze(0).to('cuda')
        attention_mask = torch.tensor(tokenized_input['attention_mask']).unsqueeze(0).to('cuda')
        labels = torch.tensor(tokenized_input['labels']).unsqueeze(0).to('cuda')
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        # Select and flatten the gradient slice chosen by the global notation.
        vectorized_grads = generate_grad(model,output_notation)
        if output_notation=='noproj':
            encoded_grads=vectorized_grads
        elif output_notation=='noproj_adam':
            # Precondition the raw gradient with the checkpoint's Adam moments.
            beta1 = 0.9
            beta2 = 0.999
            eps = 1e-08
            weight_decay=0

            updated_avg = beta1 * avg + (1 - beta1) * vectorized_grads
            updated_avg_sq = beta2 * avg_sq + (1 - beta2) * vectorized_grads ** 2
            # NOTE(review): denominator is sqrt(v + eps), not standard Adam's
            # sqrt(v) + eps — matches obtain_gradients_with_adam; confirm intended.
            encoded_grads = updated_avg / torch.sqrt(updated_avg_sq + eps) + weight_decay*vectorized_grads #consider weight decay

        else:
            # Random (Rademacher) projection to proj_d dims; the projector is
            # created once on first use and cached in the global `Project`.
            if 'Project' not in globals():
                print("set Projector!")
                global Project
                Project=BasicProjector(grad_dim=vectorized_grads.shape[0],proj_dim=proj_d,seed=seed, proj_type=ProjectionType.rademacher, device = device,dtype=torch.float16)
            encoded_grads=Project.project(vectorized_grads.reshape(1,-1).to('cuda'),model_id=Project.model_id)

        tr_grad_dict[idx] = encoded_grads
        print(f"anc_grad_dict[{idx}] = {tr_grad_dict[idx]}")

    return tr_grad_dict
154+
155+
def obtain_gradients_with_adam(model, batch, avg, avg_sq):
    """Obtain gradients preconditioned with Adam optimizer states.

    Args:
        model: module whose populated ``.grad`` tensors are read.
        batch: unused; kept for interface compatibility.
        avg: flattened Adam first moment (m).
        avg_sq: flattened Adam second moment (v).

    Returns:
        Flat tensor of Adam-update directions.
        NOTE(review): denominator is sqrt(v + eps) rather than standard Adam's
        sqrt(v) + eps — consistent with compute_gradient; confirm intended.
    """
    beta1, beta2, eps = 0.9, 0.999, 1e-08

    flat_grads = torch.cat(
        [param.grad.view(-1) for _, param in model.named_parameters() if param.grad is not None]
    )

    new_m = beta1 * avg + (1 - beta1) * flat_grads
    new_v = beta2 * avg_sq + (1 - beta2) * flat_grads ** 2
    return new_m / torch.sqrt(new_v + eps)
169+
170+
# %%
# tr_lang='English'

# --- configuration -----------------------------------------------------------
# Gradient slice / projection mode; one of:
# ['noproj_adam','noproj','q','v','qv','A','B','layers.0','layers.31']
output_notation='noproj_adam'

lang_list=["Chinese", "English", "French", "Japanese", "Russian", "Spanish"]
checkpoint_nums=['65','131','197','260']
model_name_or_path= ''  # TODO: set the base model path before running
max_token_length=2048
proj_d=8192      # projection dimension used when a projector is needed
seed=42
overwrite=False  # recompute even if the gradient pickle already exists

# Fix: device selection is loop-invariant — hoisted out of the per-language
# loop where it was recomputed on every iteration.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for tr_lang in lang_list:
    checkpoint_path_list=[f'{tr_lang}/checkpoint-{num}' for num in checkpoint_nums]
    train_set_path= f'data/{tr_lang}_anc.json'
    eval_set_path_list= [f'data/{lang}_val.json' for lang in lang_list]
    tr_grad_file_path_list=[f'{checkpoint_path}/{output_notation}/anchor_gradients.pkl' for checkpoint_path in checkpoint_path_list]

    # Anchor set for this language: a JSON list of {'input', 'output'} samples.
    with open(train_set_path, 'r', encoding='utf-8') as f:
        train_data = json.load(f)

    train_dataset = Dataset.from_list(train_data)

    for checkpoint_path, tr_grad_file_path in zip(checkpoint_path_list, tr_grad_file_path_list):
        # Skip checkpoints whose gradients were already computed (idiom fix:
        # `not overwrite` instead of `overwrite==False`; guard clause instead
        # of an `else` after `continue`).
        if os.path.exists(tr_grad_file_path) and not overwrite:
            print(f"grad has already existed in {checkpoint_path}")
            continue

        model = load_model(model_name_or_path, checkpoint_path)
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, model_max_length=max_token_length)
        print(f"Compute gradient for checkpoint {checkpoint_path}:")
        tr_grad_dict = compute_gradient(model, train_dataset, tokenizer,
                                        max_token_length=max_token_length,
                                        checkpoint_path=checkpoint_path)
        # Release the model before loading the next checkpoint.
        del model

        # Save the anchor-set gradient dictionary to a pickle file.
        dir_path = f'{checkpoint_path}/{output_notation}'
        os.makedirs(dir_path, exist_ok=True)  # race-free replacement for exists()+makedirs()
        with open(tr_grad_file_path, 'wb') as f:
            pickle.dump(tr_grad_dict, f)

        print(f"finish grad calculation in {checkpoint_path}")

0 commit comments

Comments
 (0)