train.py
import os
import torch
import pytorch_lightning as pl
from dataset import PolypDS, train_transform, val_transform
from torch.utils.data import DataLoader
from metrics import iou_score, dice_score, Loss
from model import PVTFormerNet
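# dataset, metrics, and model are local modules of this repository, providing the
# polyp dataset wrapper and transforms, the Dice/IoU metrics and combined loss,
# and the PVTFormerNet segmentation network.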
# Lightning module wrapping the segmentation model, loss, and metrics
class Segmentor(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = Loss()  # instantiate the loss once instead of per step

    def forward(self, x):
        return self.model(x)

    def _step(self, batch):
        image, y_true = batch
        y_pred = self.model(image)
        loss = self.criterion(y_pred, y_true)
        dice = dice_score(y_pred, y_true)
        iou = iou_score(y_pred, y_true)
        return loss, dice, iou

    def training_step(self, batch, batch_idx):
        loss, dice, iou = self._step(batch)
        metrics = {"loss": loss, "train_dice": dice, "train_iou": iou}
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, dice, iou = self._step(batch)
        metrics = {"val_loss": loss, "val_dice": dice, "val_iou": iou}
        self.log_dict(metrics, prog_bar=True)
        return metrics

    def test_step(self, batch, batch_idx):
        loss, dice, iou = self._step(batch)
        metrics = {"test_loss": loss, "test_dice": dice, "test_iou": iou}
        self.log_dict(metrics, prog_bar=True)
        return metrics

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=2e-3)
        # Halve the learning rate when val_dice stops improving for 10 epochs
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="max", factor=0.5, patience=10, verbose=True)
        lr_scheduler = {"scheduler": scheduler, "monitor": "val_dice"}
        return [optimizer], [lr_scheduler]
model = PVTFormerNet().cuda()
DATA_PATH = ''
# Datasets: train/val splits plus the five standard polyp segmentation test sets
# (Kvasir-SEG, ETIS-LaribPolypDB, CVC-300, CVC-ClinicDB, CVC-ColonDB)
train_ds = PolypDS(type='train', transform=train_transform)
val_ds = PolypDS(type='val', transform=val_transform)
test1_ds = PolypDS(type='test_kvasir', transform=val_transform)
test2_ds = PolypDS(type='test_etis', transform=val_transform)
test3_ds = PolypDS(type='test_cvc300', transform=val_transform)
test4_ds = PolypDS(type='test_clinic', transform=val_transform)
test5_ds = PolypDS(type='test_colon', transform=val_transform)
# DataLoader
trainloader = DataLoader(train_ds, batch_size=16, num_workers=2, shuffle=True)
valloader = DataLoader(val_ds, batch_size=4, num_workers=2, shuffle=False)
testloader1 = DataLoader(test1_ds, batch_size=1, num_workers=2, shuffle=False)
testloader2 = DataLoader(test2_ds, batch_size=1, num_workers=2, shuffle=False)
testloader3 = DataLoader(test3_ds, batch_size=1, num_workers=2, shuffle=False)
testloader4 = DataLoader(test4_ds, batch_size=4, num_workers=2, shuffle=False)
testloader5 = DataLoader(test5_ds, batch_size=1, num_workers=2, shuffle=False)
# Training config: keep only the best checkpoint by validation Dice
os.makedirs('/content/weights', exist_ok=True)
check_point = pl.callbacks.ModelCheckpoint(
    dirpath="/content/weights", filename="ckpt{val_dice:0.4f}",
    monitor="val_dice", mode="max", save_top_k=1, verbose=True,
    save_weights_only=True, auto_insert_metric_name=False)
progress_bar = pl.callbacks.TQDMProgressBar()
PARAMS = {"accelerator": "gpu", "devices": 1,  # run on a single GPU, matching the .cuda() call above
          "benchmark": True, "enable_progress_bar": True, "logger": True,
          "callbacks": [check_point, progress_bar],
          "log_every_n_steps": 1, "num_sanity_val_steps": 0, "max_epochs": 200,
          "precision": 16}  # mixed-precision training
trainer = pl.Trainer(**PARAMS)
segmentor = Segmentor(model=model)
# Training
trainer.fit(segmentor, trainloader, valloader)
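# Evaluation (optional): the five test loaders above are defined but never used in this
# script. A minimal sketch of benchmarking on each of them, assuming the standard
# pl.Trainer.test API with ckpt_path="best" to reload the top-scoring checkpoint.
for name, loader in [("kvasir", testloader1), ("etis", testloader2),
                     ("cvc300", testloader3), ("clinic", testloader4),
                     ("colon", testloader5)]:
    print(f"Testing on {name}")
    trainer.test(segmentor, dataloaders=loader, ckpt_path="best")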