evaluate.py
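"""Evaluate model-editing algorithms (ROME, FT, KN, MEND, KE) on the
CounterFact ("cf") or zsRE ("zsre") dataset, saving per-case pre- and
post-edit metrics as JSON under RESULTS_DIR.
"""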
import json
import os
import shutil
from pathlib import Path
from time import time
from typing import Tuple, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from baselines.efk import EFKHyperParams, EfkRewriteExecutor
from baselines.ft import FTHyperParams, apply_ft_to_model
from baselines.kn import KNHyperParams, apply_kn_to_model
from baselines.mend import MENDHyperParams, MendRewriteExecutor
from dsets import (
    AttributeSnippets,
    CounterFactDataset,
    MENDQADataset,
    get_tfidf_vectorizer,
)
from experiments.py.eval_utils_counterfact import compute_rewrite_quality_counterfact
from experiments.py.eval_utils_zsre import compute_rewrite_quality_zsre
from rome import ROMEHyperParams, apply_rome_to_model
from util import nethook
from util.globals import *
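
# NOTE: the star import above is expected to supply the RESULTS_DIR,
# HPARAMS_DIR, and DATA_DIR path constants used throughout this script.

# Map each algorithm name to its (hyperparameter class, editing function) pair.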
ALG_DICT = {
    "ROME": (ROMEHyperParams, apply_rome_to_model),
    "FT": (FTHyperParams, apply_ft_to_model),
    "KN": (KNHyperParams, apply_kn_to_model),
    "MEND": (MENDHyperParams, MendRewriteExecutor().apply_to_model),
    "KE": (EFKHyperParams, EfkRewriteExecutor().apply_to_model),
}
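
# Map each dataset name to its (dataset class, evaluation function) pair.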
DS_DICT = {
    "cf": (CounterFactDataset, compute_rewrite_quality_counterfact),
    "zsre": (MENDQADataset, compute_rewrite_quality_zsre),
}


def main(
    alg_name: str,
    model_name: Union[str, Tuple],
    hparams_fname: str,
    ds_name: str,
    dataset_size_limit: int,
    continue_from_run: str,
    skip_generation_tests: bool,
    conserve_memory: bool,
    dir_name: str,
):
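    """
    Applies the requested editing algorithm to each dataset record, evaluates
    the model before and after every edit, and writes per-case metrics to
    RESULTS_DIR/<dir_name>/run_<id>/case_<case_id>.json.
    """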
    # Set algorithm-specific variables
    params_class, apply_algo = ALG_DICT[alg_name]

    # Determine run directory
    if continue_from_run is not None:
        run_dir = RESULTS_DIR / dir_name / continue_from_run
        assert (
            run_dir.exists()
        ), f"If continuing from run, {continue_from_run} must exist!"
    else:
        alg_dir = RESULTS_DIR / dir_name
        if alg_dir.exists():
            id_list = [
                int(str(x).split("_")[-1])
                for x in alg_dir.iterdir()
                if str(x).split("_")[-1].isnumeric()
            ]
            run_id = 0 if not id_list else max(id_list) + 1
        else:
            run_id = 0
        run_dir = RESULTS_DIR / dir_name / f"run_{str(run_id).zfill(3)}"
        run_dir.mkdir(parents=True, exist_ok=True)
    print(f"Results will be stored at {run_dir}")

    # Get run hyperparameters
    params_path = (
        run_dir / "params.json"
        if continue_from_run is not None
        else HPARAMS_DIR / alg_name / hparams_fname
    )
    hparams = params_class.from_json(params_path)
    if not (run_dir / "params.json").exists():
        shutil.copyfile(params_path, run_dir / "params.json")
    print(f"Executing {alg_name} with parameters {hparams}")

    # Instantiate vanilla model
    print("Instantiating model")
    if isinstance(model_name, str):
        model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
        tok = AutoTokenizer.from_pretrained(model_name)
        tok.pad_token = tok.eos_token
    else:
        model, tok = model_name

    # Load data
    print("Loading dataset, attribute snippets, tf-idf data")
    snips = AttributeSnippets(DATA_DIR) if not skip_generation_tests else None
    vec = get_tfidf_vectorizer(DATA_DIR) if not skip_generation_tests else None
    ds_class, ds_eval_method = DS_DICT[ds_name]
    ds = ds_class(DATA_DIR, size=dataset_size_limit, tok=tok)

    # Iterate through dataset
    for record in ds:
        case_id = record["case_id"]
        case_result_path = run_dir / f"case_{case_id}.json"
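        # Skip cases whose results already exist on disk; this is what makes
        # --continue_from_run resumable.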
        if not case_result_path.exists():
            # Compute weight changes + record weights that changed
            start = time()
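            # If conserving memory, ask the editor to return its backup of the
            # original weights on CPU rather than GPU.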
            args_conserve_memory = (
                dict(return_orig_weights_device=("cpu" if conserve_memory else "cuda"))
                if conserve_memory
                else dict()
            )
            edited_model, weights_copy = apply_algo(
                model,
                tok,
                [record["requested_rewrite"]],
                hparams,
                copy=False,
                return_orig_weights=True,
                **args_conserve_memory,
            )
            exec_time = time() - start
            print("Execution took", exec_time)

            # Execute evaluation suite
            start = time()
            metrics = {
                "case_id": case_id,
                "requested_rewrite": record["requested_rewrite"],
                "time": exec_time,
                "post": ds_eval_method(edited_model, tok, record, snips, vec),
            }
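
            # Restore the original weights so pre-edit metrics are computed on
            # the unedited model.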
            with torch.no_grad():
                for k, v in weights_copy.items():
                    nethook.get_parameter(model, k)[...] = v.to("cuda")
            metrics["pre"] = ds_eval_method(model, tok, record, snips, vec)
            print("Evaluation took", time() - start)

            # Dump metrics in .json
            with open(case_result_path, "w") as f:
                json.dump(metrics, f, indent=1)
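

# Example invocation, assuming this file lives at experiments/evaluate.py as
# in the original ROME repository:
#   python3 -m experiments.evaluate \
#       --alg_name=ROME --model_name=gpt2-xl --hparams_fname=gpt2-xl.json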
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--alg_name",
        choices=["ROME", "FT", "KN", "MEND", "KE"],
        default="ROME",
        help="Editing algorithm to use. Results are saved in results/<alg_name>/<run_id>, "
        "where a new run_id is generated on each run. "
        "If continuing from a previous run, specify its run_id in --continue_from_run.",
        required=True,
    )
    parser.add_argument(
        "--model_name",
        choices=["gpt2-medium", "gpt2-large", "gpt2-xl", "EleutherAI/gpt-j-6B"],
        default="gpt2-xl",
        help="Model to edit.",
        required=True,
    )
    parser.add_argument(
        "--hparams_fname",
        type=str,
        default="gpt2-xl.json",
        help="Name of the hyperparameters file, located in the hparams/<alg_name> folder.",
        required=True,
    )
    parser.add_argument(
        "--ds_name",
        choices=["cf", "zsre"],
        default="cf",
        help="Dataset to perform evaluations on. Either CounterFact (cf) or zsRE (zsre).",
    )
    parser.add_argument(
        "--continue_from_run",
        type=str,
        default=None,
        help="If continuing from a previous run, set to its run_id. Otherwise, leave as None.",
    )
    parser.add_argument(
        "--dataset_size_limit",
        type=int,
        default=10000,
        help="Truncate the dataset to its first n records.",
    )
    parser.add_argument(
        "--skip_generation_tests",
        dest="skip_generation_tests",
        action="store_true",
        help="Only run fast probability-based tests, skipping the slow generation tests. "
        "Useful for quick debugging and hyperparameter sweeps.",
    )
    parser.add_argument(
        "--conserve_memory",
        dest="conserve_memory",
        action="store_true",
        help="Reduce memory usage during evaluation at the cost of a minor slowdown. "
        "Backs up model weights on CPU instead of GPU.",
    )
    parser.set_defaults(skip_generation_tests=False, conserve_memory=False)
    args = parser.parse_args()

    main(
        args.alg_name,
        args.model_name,
        args.hparams_fname,
        args.ds_name,
        args.dataset_size_limit,
        args.continue_from_run,
        args.skip_generation_tests,
        args.conserve_memory,
        dir_name=args.alg_name,
    )