# %%
!pip install transformers evaluate datasets
# %%
import requests
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification, EvalPrediction
from tqdm import tqdm
device = "cuda" if torch.cuda.is_available() else "cpu"
# %%
# the model name
model_name = "google/vit-base-patch16-224"
# load the image processor
image_processor = ViTImageProcessor.from_pretrained(model_name)
# loading the pre-trained model
model = ViTForImageClassification.from_pretrained(model_name).to(device)
# %%
import urllib.parse as parse
import os
# a function to determine whether a string is a URL or not
def is_url(string):
    try:
        result = parse.urlparse(string)
        return all([result.scheme, result.netloc, result.path])
    except ValueError:
        return False

# a function to load an image from a URL or a local path
def load_image(image_path):
    if is_url(image_path):
        return Image.open(requests.get(image_path, stream=True).raw)
    elif os.path.exists(image_path):
        return Image.open(image_path)
# %%
def get_prediction(model, url_or_path):
    # load the image
    img = load_image(url_or_path)
    # preprocess the image
    pixel_values = image_processor(img, return_tensors="pt")["pixel_values"].to(device)
    # perform inference
    output = model(pixel_values)
    # get the label id and return the class name
    return model.config.id2label[int(output.logits.softmax(dim=1).argmax())]
# %%
get_prediction(model, "http://images.cocodataset.org/test-stuff2017/000000000128.jpg")
# %% [markdown]
# # Loading our Dataset
# %%
from datasets import load_dataset
# download & load the dataset
ds = load_dataset("food101")
# %% [markdown]
# ## Loading a Custom Dataset using `ImageFolder`
# Run the three cells below to load a custom dataset (one that isn't on the Hub) using `ImageFolder`.
# %%
import requests
from tqdm import tqdm
def get_file(url):
    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('content-length', 0))
    filename = None
    content_disposition = response.headers.get('content-disposition')
    if content_disposition:
        parts = content_disposition.split(';')
        for part in parts:
            if 'filename' in part:
                filename = part.split('=')[1].strip('"')
    if not filename:
        filename = os.path.basename(url)
    block_size = 1024 # 1 Kibibyte
    tqdm_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
    with open(filename, 'wb') as file:
        for data in response.iter_content(block_size):
            tqdm_bar.update(len(data))
            file.write(data)
    tqdm_bar.close()
    print(f"Downloaded {filename} ({total_size} bytes)")
    return filename
# %%
import zipfile
import os
def download_and_extract_dataset():
    # dataset from https://github.com/udacity/dermatologist-ai
    # 5.3GB
    train_url = "https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/skin-cancer/train.zip"
    # 824.5MB
    valid_url = "https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/skin-cancer/valid.zip"
    # 5.1GB
    test_url = "https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/skin-cancer/test.zip"
    for download_link in [valid_url, train_url, test_url]:
        data_dir = get_file(download_link)
        print("Extracting", download_link)
        with zipfile.ZipFile(data_dir, "r") as z:
            z.extractall("data")
        # remove the temporary zip file
        os.remove(data_dir)

# comment out the line below if you already downloaded the dataset
download_and_extract_dataset()
# %%
from datasets import load_dataset
# load the custom dataset
ds = load_dataset("imagefolder", data_dir="data")
# %% [markdown]
# # Exploring the Data
# %%
ds
# %%
labels = ds["train"].features["label"]
labels
# %%
labels.int2str(ds["train"][532]["label"])
# %%
import random
import matplotlib.pyplot as plt
def show_image_grid(dataset, split, grid_size=(4,4)):
    # select random images from the given split
    indices = random.sample(range(len(dataset[split])), grid_size[0]*grid_size[1])
    images = [dataset[split][i]["image"] for i in indices]
    labels = [dataset[split][i]["label"] for i in indices]
    # display the images in a grid
    fig, axes = plt.subplots(nrows=grid_size[0], ncols=grid_size[1], figsize=(8,8))
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i])
        ax.axis('off')
        # use the passed split rather than hard-coding "train"
        ax.set_title(dataset[split].features["label"].int2str(labels[i]))
    plt.show()
# %%
show_image_grid(ds, "train")
# %% [markdown]
# # Preprocessing the Data
# %%
def transform(examples):
    # convert all images to RGB, then preprocess them with our image processor
    inputs = image_processor([img.convert("RGB") for img in examples["image"]], return_tensors="pt")
    # don't forget to include the labels as well
    inputs["labels"] = examples["label"]
    return inputs
# %%
# use the with_transform() method to apply the transform to the dataset on the fly during training
dataset = ds.with_transform(transform)
# %%
for item in dataset["train"]:
print(item["pixel_values"].shape)
print(item["labels"])
break
# %%
# extract the labels for our dataset
labels = ds["train"].features["label"].names
labels
# %%
import torch
def collate_fn(batch):
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.tensor([x["labels"] for x in batch]),
    }
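# %%
# A quick sanity check (a sketch): build one batch with the collator to confirm
# the shapes the model will receive during training.
from torch.utils.data import DataLoader
sample_batch = next(iter(DataLoader(dataset["train"], collate_fn=collate_fn, batch_size=4)))
print(sample_batch["pixel_values"].shape)  # e.g. torch.Size([4, 3, 224, 224])
print(sample_batch["labels"])              # tensor of 4 label ids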
# %% [markdown]
# # Defining the Metrics
# %%
from evaluate import load
import numpy as np
# load the accuracy and f1 metrics from the evaluate module
accuracy = load("accuracy")
f1 = load("f1")
def compute_metrics(eval_pred):
    # compute the accuracy and f1 scores & return them
    accuracy_score = accuracy.compute(predictions=np.argmax(eval_pred.predictions, axis=1), references=eval_pred.label_ids)
    f1_score = f1.compute(predictions=np.argmax(eval_pred.predictions, axis=1), references=eval_pred.label_ids, average="macro")
    return {**accuracy_score, **f1_score}
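# %%
# A minimal self-test of compute_metrics with dummy logits; EvalPrediction is
# the same object the Trainer passes to it during evaluation.
from transformers import EvalPrediction
dummy_pred = EvalPrediction(predictions=np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]),
                            label_ids=np.array([1, 0, 0]))
compute_metrics(dummy_pred)  # expect accuracy and macro f1 of about 0.67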
# %% [markdown]
# # Training the Model
# %%
# load the ViT model
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)},
    ignore_mismatched_sizes=True,
)
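# %%
# Because ignore_mismatched_sizes=True, the checkpoint's 1,000-class ImageNet
# head is replaced by a freshly initialized head sized to our label count.
# A quick look at the new head (for vit-base, in_features is 768):
print(model.classifier)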
# %%
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir="./vit-base-food",       # output directory
    # output_dir="./vit-base-skin-cancer",
    per_device_train_batch_size=32,     # batch size per device during training
    evaluation_strategy="steps",        # evaluation strategy to adopt during training
    num_train_epochs=3,                 # total number of training epochs
    # fp16=True,                        # use mixed precision
    save_steps=1000,                    # number of update steps before saving a checkpoint
    eval_steps=1000,                    # number of update steps before evaluating
    logging_steps=1000,                 # number of update steps before logging
    # save_steps=50,
    # eval_steps=50,
    # logging_steps=50,
    save_total_limit=2,                 # limit the total number of checkpoints on disk
    remove_unused_columns=False,        # keep the "image" column so our transform can see it
    push_to_hub=False,                  # do not push the model to the hub
    report_to='tensorboard',            # report metrics to tensorboard
    load_best_model_at_end=True,        # load the best model at the end of training
)
# %%
from transformers import Trainer
trainer = Trainer(
    model=model,                        # the instantiated 🤗 Transformers model to be trained
    args=training_args,                 # training arguments, defined above
    data_collator=collate_fn,           # the data collator used for batching
    compute_metrics=compute_metrics,    # the metrics function used for evaluation
    train_dataset=dataset["train"],     # training dataset
    eval_dataset=dataset["validation"], # evaluation dataset
    tokenizer=image_processor,          # the processor used for preprocessing the images
)
# %%
# start training
trainer.train()
# %%
# trainer.evaluate(dataset["test"])
trainer.evaluate()
# %%
# start tensorboard
# %load_ext tensorboard
%reload_ext tensorboard
%tensorboard --logdir ./vit-base-food/runs
# %% [markdown]
# ## Alternatively: Training using PyTorch Loop
# Run the two cells below if you'd rather fine-tune with a regular PyTorch training loop instead of the Trainer API.
# %%
# Training loop
from torch.utils.tensorboard import SummaryWriter
from torch.optim import AdamW
from torch.utils.data import DataLoader

batch_size = 32
train_dataset_loader = DataLoader(dataset["train"], collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
valid_dataset_loader = DataLoader(dataset["validation"], collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
# define the optimizer
optimizer = AdamW(model.parameters(), lr=1e-5)
log_dir = "./image-classification/tensorboard"
summary_writer = SummaryWriter(log_dir=log_dir)
num_epochs = 3
model = model.to(device)
# number of training steps
n_train_steps = num_epochs * len(train_dataset_loader)
# number of validation steps
n_valid_steps = len(valid_dataset_loader)
# current training step
current_step = 0
# log, evaluate & save every `save_steps` steps
save_steps = 1000

# unlike the Trainer version above, this compute_metrics receives predictions
# that are already argmax'ed into class ids
def compute_metrics(eval_pred):
    accuracy_score = accuracy.compute(predictions=eval_pred.predictions, references=eval_pred.label_ids)
    f1_score = f1.compute(predictions=eval_pred.predictions, references=eval_pred.label_ids, average="macro")
    return {**accuracy_score, **f1_score}
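# %%
# Optional (a sketch): a linear learning-rate schedule, similar to what the
# Trainer uses by default. If you enable it, call lr_scheduler.step() right
# after optimizer.step() in the loop below.
# from transformers import get_scheduler
# lr_scheduler = get_scheduler("linear", optimizer=optimizer,
#                              num_warmup_steps=0, num_training_steps=n_train_steps)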
# %%
for epoch in range(num_epochs):
    # set the model to training mode
    model.train()
    # initialize the training loss
    train_loss = 0
    # initialize the progress bar
    progress_bar = tqdm(range(current_step, n_train_steps), "Training", dynamic_ncols=True)
    for batch in train_dataset_loader:
        if (current_step+1) % save_steps == 0:
            ### evaluation code ###
            # evaluate on the validation set whenever the current step is a multiple of save_steps
            print()
            print(f"Validation at step {current_step}...")
            print()
            # set the model to evaluation mode
            model.eval()
            # initialize the lists that store the predictions and the labels
            predictions, labels = [], []
            # initialize the validation loss
            valid_loss = 0
            # use a separate loop variable so we don't overwrite the training `batch`,
            # and disable gradient tracking during validation
            with torch.no_grad():
                for valid_batch in valid_dataset_loader:
                    # get the batch
                    pixel_values = valid_batch["pixel_values"].to(device)
                    label_ids = valid_batch["labels"].to(device)
                    # forward pass
                    outputs = model(pixel_values=pixel_values, labels=label_ids)
                    # accumulate the loss
                    valid_loss += outputs.loss.item()
                    # move the logits off the GPU
                    logits = outputs.logits.detach().cpu()
                    # store the predictions and the labels
                    predictions.extend(logits.argmax(dim=-1).tolist())
                    labels.extend(label_ids.tolist())
            # make the EvalPrediction object that the compute_metrics function expects
            eval_prediction = EvalPrediction(predictions=predictions, label_ids=labels)
            # compute the metrics
            metrics = compute_metrics(eval_prediction)
            # print the stats
            print()
            print(f"Epoch: {epoch}, Step: {current_step}, Train Loss: {train_loss / save_steps:.4f}, " +
                  f"Valid Loss: {valid_loss / n_valid_steps:.4f}, Accuracy: {metrics['accuracy']}, " +
                  f"F1 Score: {metrics['f1']}")
            print()
            # log the metrics
            summary_writer.add_scalar("valid_loss", valid_loss / n_valid_steps, global_step=current_step)
            summary_writer.add_scalar("accuracy", metrics["accuracy"], global_step=current_step)
            summary_writer.add_scalar("f1", metrics["f1"], global_step=current_step)
            # save the model & processor
            model.save_pretrained(f"./vit-base-food/checkpoint-{current_step}")
            image_processor.save_pretrained(f"./vit-base-food/checkpoint-{current_step}")
            # get the model back to train mode
            model.train()
            # reset the train and valid loss
            train_loss, valid_loss = 0, 0
        ### training code below ###
        # get the batch & move it to the device
        pixel_values = batch["pixel_values"].to(device)
        label_ids = batch["labels"].to(device)
        # forward pass
        outputs = model(pixel_values=pixel_values, labels=label_ids)
        # get the loss
        loss = outputs.loss
        # backward pass
        loss.backward()
        # update the weights
        optimizer.step()
        # zero the gradients
        optimizer.zero_grad()
        # log the loss
        loss_v = loss.item()
        train_loss += loss_v
        # increment the step
        current_step += 1
        progress_bar.update(1)
        # log the training loss
        summary_writer.add_scalar("train_loss", loss_v, global_step=current_step)
# %% [markdown]
# # Performing Inference
# %%
# load the best model, change the checkpoint number to the best checkpoint
# if the last checkpoint is the best, then ignore this cell
best_checkpoint = 7000
# best_checkpoint = 150
model = ViTForImageClassification.from_pretrained(f"./vit-base-food/checkpoint-{best_checkpoint}").to(device)
# model = ViTForImageClassification.from_pretrained(f"./vit-base-skin-cancer/checkpoint-{best_checkpoint}").to(device)
# %%
get_prediction(model, "https://images.pexels.com/photos/858496/pexels-photo-858496.jpeg?auto=compress&cs=tinysrgb&w=600&lazy=load")
# %%
def get_prediction_probs(model, url_or_path, num_classes=3):
    # load the image
    img = load_image(url_or_path)
    # preprocess the image
    pixel_values = image_processor(img, return_tensors="pt")["pixel_values"].to(device)
    # perform inference
    output = model(pixel_values)
    # get the top k classes and probabilities
    probs, indices = torch.topk(output.logits.softmax(dim=1), k=num_classes)
    # get the class labels
    id2label = model.config.id2label
    classes = [id2label[idx.item()] for idx in indices[0]]
    # convert the probabilities to a list
    probs = probs.squeeze().tolist()
    # map each class name to its probability
    results = dict(zip(classes, probs))
    return results
# %%
# example 1
get_prediction_probs(model, "https://images.pexels.com/photos/406152/pexels-photo-406152.jpeg?auto=compress&cs=tinysrgb&w=600")
# %%
# example 2
get_prediction_probs(model, "https://images.pexels.com/photos/920220/pexels-photo-920220.jpeg?auto=compress&cs=tinysrgb&w=600")
# %%
# example 3
get_prediction_probs(model, "https://images.pexels.com/photos/3338681/pexels-photo-3338681.jpeg?auto=compress&cs=tinysrgb&w=600")
# %%
# example 4
get_prediction_probs(model, "https://images.pexels.com/photos/806457/pexels-photo-806457.jpeg?auto=compress&cs=tinysrgb&w=600", num_classes=10)
# %%
get_prediction_probs(model, "https://images.pexels.com/photos/1624487/pexels-photo-1624487.jpeg?auto=compress&cs=tinysrgb&w=600")