-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathmain.py
262 lines (222 loc) · 10.5 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import argparse
import os
import wget
import torch
import clip
import os
import json
import operator
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import ImageNet2p, ImageNet, ImageNetV2, ImageNetSketch, ImageNetR, ObjectNet, ImageNetA
from utils import get_model_from_sd, test_model_on_dataset
def parse_arguments():
    """Build the command-line interface for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``data_location``,
        ``model_location``, ``download_models``, ``eval_individual_models``,
        ``uniform_soup``, ``greedy_soup``, ``plot``, ``batch_size`` and
        ``workers``.
    """
    cli = argparse.ArgumentParser()

    # Filesystem locations; "~" is expanded once, when the parser is built.
    path_options = (
        ("--data-location", '~/data', "The root directory for the datasets."),
        ("--model-location", '~/ssd/checkpoints/soups', "Where to download the models."),
    )
    for flag, raw_default, help_text in path_options:
        cli.add_argument(
            flag,
            type=str,
            default=os.path.expanduser(raw_default),
            help=help_text,
        )

    # Pipeline stages: each one is skipped unless its flag is passed.
    stage_flags = (
        "--download-models",
        "--eval-individual-models",
        "--uniform-soup",
        "--greedy-soup",
        "--plot",
    )
    for stage_flag in stage_flags:
        cli.add_argument(stage_flag, action="store_true", default=False)

    # Sizes forwarded to the dataset constructors.
    for flag, int_default in (("--batch-size", 256), ("--workers", 8)):
        cli.add_argument(flag, type=int, default=int_default)

    return cli.parse_args()
# Number of fine-tuned checkpoints handled by this script (model_0.pt ... model_71.pt).
NUM_MODELS = 72
INDIVIDUAL_MODEL_RESULTS_FILE = 'individual_model_results.jsonl'
UNIFORM_SOUP_RESULTS_FILE = 'uniform_soup_results.jsonl'
GREEDY_SOUP_RESULTS_FILE = 'greedy_soup_results.jsonl'


def _load_state_dict(model_path):
    """Load the checkpoint at ``model_path`` onto the CPU.

    Every stage maps to CPU so that behavior does not depend on which
    device the checkpoint was saved from.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist (an explicit raise
            instead of ``assert``, which ``python -O`` would strip).
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f'Missing checkpoint {model_path}; run with --download-models first.')
    return torch.load(model_path, map_location=torch.device('cpu'))


def _remove_if_exists(path):
    """Delete ``path`` if present, so a stage starts from an empty results file."""
    if os.path.exists(path):
        os.remove(path)


def _append_jsonl(path, record):
    """Append ``record`` to ``path`` as a single JSON line."""
    with open(path, 'a') as f:
        f.write(json.dumps(record) + '\n')


def _download_models(model_location, num_models=NUM_MODELS):
    """Download model_0.pt ... model_{num_models-1}.pt into ``model_location``.

    Files that already exist are skipped. ``wget.download`` does not
    overwrite an existing file; it saves a suffixed copy instead
    (e.g. "model_0 (1).pt"), which nothing in this script would read.
    """
    # makedirs also creates missing parent directories; os.mkdir would fail on them.
    os.makedirs(model_location, exist_ok=True)
    for i in range(num_models):
        destination = os.path.join(model_location, f'model_{i}.pt')
        if os.path.exists(destination):
            print(f'\nSkipping model {i}: {destination} already exists.')
            continue
        print(f'\nDownloading model {i} of {num_models - 1}')
        wget.download(
            f'https://github.com/mlfoundations/model-soups/releases/download/v0.0.2/model_{i}.pt',
            out=model_location
        )


def _evaluate_on_all_datasets(model, preprocess, args, progress=''):
    """Return ``{dataset class name: accuracy}`` for ``model`` on every evaluation set.

    ``progress`` is inserted into each progress message,
    e.g. ``'model 3 of 71 '``.
    """
    # Note: ImageNet2p is the held-out minival set from ImageNet train that we use.
    # It is called 2p for 2 percent of ImageNet, or 26k images.
    # See utils on how this dataset is handled slightly differently.
    accuracies = {}
    for dataset_cls in [ImageNet2p, ImageNet, ImageNetV2, ImageNetSketch, ImageNetR, ObjectNet, ImageNetA]:
        print(f'Evaluating {progress}on {dataset_cls.__name__}.')
        dataset = dataset_cls(preprocess, args.data_location, args.batch_size, args.workers)
        accuracy = test_model_on_dataset(model, dataset)
        accuracies[dataset_cls.__name__] = accuracy
        print(accuracy)
    return accuracies


def _uniform_soup(model_paths):
    """Average the state dicts at ``model_paths`` with equal weights.

    The checkpoints are loaded one at a time, so only the running sum and
    one checkpoint are held in memory at once.

    Raises:
        ValueError: if ``model_paths`` is empty.
    """
    num_models = len(model_paths)
    if num_models == 0:
        raise ValueError('Cannot build a uniform soup from zero models.')
    soup = None
    for j, model_path in enumerate(model_paths):
        print(f'Adding model {j} of {num_models - 1} to uniform soup.')
        state_dict = _load_state_dict(model_path)
        if soup is None:
            soup = {k: v * (1. / num_models) for k, v in state_dict.items()}
        else:
            soup = {k: v * (1. / num_models) + soup[k] for k, v in state_dict.items()}
    return soup


def _rank_by_held_out_accuracy(individual_model_db):
    """Return ``[(model_name, ImageNet2p accuracy), ...]`` sorted by decreasing accuracy.

    The list is sorted ascending and then reversed, so tied models come out
    in reverse file order. This matches the ordering the original loop produced.
    """
    val_accs = dict(zip(individual_model_db['model_name'], individual_model_db['ImageNet2p']))
    ranked = sorted(val_accs.items(), key=operator.itemgetter(1))
    ranked.reverse()
    return ranked


def _mix_in(soup_params, new_params, num_ingredients):
    """Return the running average of ``soup_params`` and ``new_params``.

    ``soup_params`` is treated as the mean of ``num_ingredients`` models, and
    ``new_params`` is added with weight ``1 / (num_ingredients + 1)``. Keys are
    taken from ``new_params``.
    """
    # Arithmetic produces new tensors, so the inputs are never modified (no clone needed).
    return {
        k: soup_params[k] * (num_ingredients / (num_ingredients + 1.)) +
        new_params[k] * (1. / (num_ingredients + 1))
        for k in new_params
    }


def _greedy_soup(args, base_model, preprocess):
    """Build a greedy soup and return its parameters.

    Models are visited in order of decreasing held-out (ImageNet2p)
    accuracy, as recorded in INDIVIDUAL_MODEL_RESULTS_FILE. A model is
    kept only if adding it raises the soup's held-out accuracy.

    Raises:
        FileNotFoundError: if the individual-model results file is missing.
        ValueError: if that file lists no models.
    """
    if not os.path.exists(INDIVIDUAL_MODEL_RESULTS_FILE):
        raise FileNotFoundError(
            f'{INDIVIDUAL_MODEL_RESULTS_FILE} not found; run with --eval-individual-models first.')
    ranked = _rank_by_held_out_accuracy(pd.read_json(INDIVIDUAL_MODEL_RESULTS_FILE, lines=True))
    if not ranked:
        raise ValueError(f'{INDIVIDUAL_MODEL_RESULTS_FILE} contains no models.')
    sorted_models = [name for name, _ in ranked]

    def checkpoint_path(name):
        return os.path.join(args.model_location, f'{name}.pt')

    # Start the soup with the best individual model.
    greedy_soup_ingredients = [sorted_models[0]]
    greedy_soup_params = _load_state_dict(checkpoint_path(sorted_models[0]))
    best_val_acc_so_far = ranked[0][1]
    held_out_val_set = ImageNet2p(preprocess, args.data_location, args.batch_size, args.workers)

    # Bound the loop by the models actually present in the results file,
    # which may list fewer than NUM_MODELS.
    num_candidates = len(sorted_models)
    for i in range(1, num_candidates):
        print(f'Testing model {i} of {num_candidates - 1}')
        new_ingredient_params = _load_state_dict(checkpoint_path(sorted_models[i]))
        potential_greedy_soup_params = _mix_in(
            greedy_soup_params, new_ingredient_params, len(greedy_soup_ingredients))

        # Run the potential greedy soup on the held-out val set.
        model = get_model_from_sd(potential_greedy_soup_params, base_model)
        held_out_val_accuracy = test_model_on_dataset(model, held_out_val_set)

        # Keep the new model only if held-out accuracy strictly improves.
        print(f'Potential greedy soup val acc {held_out_val_accuracy}, best so far {best_val_acc_so_far}.')
        if held_out_val_accuracy > best_val_acc_so_far:
            greedy_soup_ingredients.append(sorted_models[i])
            best_val_acc_so_far = held_out_val_accuracy
            greedy_soup_params = potential_greedy_soup_params
            print(f'Adding to soup. New soup is {greedy_soup_ingredients}')
    return greedy_soup_params


def _with_ood_column(db):
    """Add (in place) and return an 'OOD' column on ``db``.

    'OOD' is the mean of the ImageNetV2, ImageNetR, ImageNetSketch,
    ObjectNet and ImageNetA accuracy columns.
    """
    db['OOD'] = 1./5 * (db['ImageNetV2'] +
        db['ImageNetR'] + db['ImageNetSketch'] +
        db['ObjectNet'] + db['ImageNetA'])
    return db


def _plot_results(output_path='figure.png'):
    """Plot ImageNet accuracy against mean OOD accuracy for all three results files.

    The figure is saved to ``output_path``.
    """
    individual_model_db = _with_ood_column(pd.read_json(INDIVIDUAL_MODEL_RESULTS_FILE, lines=True))
    uniform_soup_db = _with_ood_column(pd.read_json(UNIFORM_SOUP_RESULTS_FILE, lines=True))
    greedy_soup_db = _with_ood_column(pd.read_json(GREEDY_SOUP_RESULTS_FILE, lines=True))

    fig = plt.figure(constrained_layout=True, figsize=(8, 6))
    ax = fig.subplots()
    ax.scatter(
        greedy_soup_db['ImageNet'],
        greedy_soup_db['OOD'],
        marker='*',
        color='C4',
        s=400,
        label='Greedy Soup',
        zorder=10
    )
    ax.scatter(
        uniform_soup_db['ImageNet'],
        uniform_soup_db['OOD'],
        marker='o',
        color='C0',
        s=200,
        label='Uniform Soup',
        zorder=10
    )
    # Row 0 is model_0 (rows are written in model order); it is plotted
    # separately under the 'Initialization (LP)' label.
    ax.scatter(
        individual_model_db['ImageNet'].values[0],
        individual_model_db['OOD'].values[0],
        marker='h',
        color='slategray',
        s=150,
        label='Initialization (LP)',
        zorder=10
    )
    ax.scatter(
        individual_model_db['ImageNet'].values[1:],
        individual_model_db['OOD'].values[1:],
        marker='d',
        color='C2',
        s=130,
        label='Various hyperparameters',
        zorder=10
    )
    ax.set_ylabel('Avg. accuracy on 5 distribution shifts', fontsize=16)
    ax.set_xlabel('ImageNet Accuracy (top-1%)', fontsize=16)
    ax.grid()
    ax.legend(fontsize=13)
    plt.savefig(output_path, bbox_inches='tight')


def main():
    """Run the pipeline stages selected on the command line, in order."""
    args = parse_arguments()

    # Step 1: Download models.
    if args.download_models:
        _download_models(args.model_location)

    model_paths = [os.path.join(args.model_location, f'model_{i}.pt') for i in range(NUM_MODELS)]

    # Every model-building stage passes its state dict to get_model_from_sd
    # together with this base CLIP model.
    if args.eval_individual_models or args.uniform_soup or args.greedy_soup:
        base_model, preprocess = clip.load('ViT-B/32', 'cpu', jit=False)

    # Step 2: Evaluate individual models. Results are appended one model at a time.
    if args.eval_individual_models:
        _remove_if_exists(INDIVIDUAL_MODEL_RESULTS_FILE)
        for j, model_path in enumerate(model_paths):
            model = get_model_from_sd(_load_state_dict(model_path), base_model)
            results = {'model_name': f'model_{j}'}
            results.update(_evaluate_on_all_datasets(
                model, preprocess, args, f'model {j} of {NUM_MODELS - 1} '))
            _append_jsonl(INDIVIDUAL_MODEL_RESULTS_FILE, results)

    # Step 3: Uniform Soup.
    if args.uniform_soup:
        _remove_if_exists(UNIFORM_SOUP_RESULTS_FILE)
        model = get_model_from_sd(_uniform_soup(model_paths), base_model)
        results = {'model_name': 'uniform_soup'}
        results.update(_evaluate_on_all_datasets(model, preprocess, args))
        _append_jsonl(UNIFORM_SOUP_RESULTS_FILE, results)

    # Step 4: Greedy Soup.
    if args.greedy_soup:
        _remove_if_exists(GREEDY_SOUP_RESULTS_FILE)
        model = get_model_from_sd(_greedy_soup(args, base_model, preprocess), base_model)
        results = {'model_name': 'greedy_soup'}
        results.update(_evaluate_on_all_datasets(model, preprocess, args))
        _append_jsonl(GREEDY_SOUP_RESULTS_FILE, results)

    # Step 5: Plot.
    if args.plot:
        _plot_results()


if __name__ == '__main__':
    main()