-
Notifications
You must be signed in to change notification settings - Fork 3
/
fake_labels_generator.py
100 lines (72 loc) · 2.38 KB
/
fake_labels_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
from glob import glob
from dataset import SingleDataset
from utils import dice_coef_2d
import random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.transforms as transforms
from torchvision import models
from torchvision.models.vgg import VGG
import torch.nn.functional as F
import numpy as np
import math
from PIL import Image
from datetime import datetime
from model import *
import torch.optim as optim
import matplotlib.pyplot as plt
import argparse
# Command-line options for generating masks ("fake labels") for a folder of
# images with a trained supervised model.
parser = argparse.ArgumentParser()
# Only 'Res_Unet' is handled by the model-selection code of this script, so
# it is the default; the previous default ('U_net') always aborted.
parser.add_argument('--model_type', type=str, default='Res_Unet',
                    help='type of model')
parser.add_argument('--model_path', type=str, default=r'',
                    help='path of supervised model')
parser.add_argument('--Image_dir', type=str, default=r'',
                    help='path of images to be labeled')
# argparse derives the attribute name from the first long option, turning
# '--Mask-dir' into ``opt.Mask_dir``; '--Mask_dir' is accepted as well so the
# spelling matches '--Image_dir'.
parser.add_argument('--Mask-dir', '--Mask_dir', type=str,
                    default='test_img_pre',
                    help='path of masks to be saved')
opt = parser.parse_args()
def cycle(iterable):
    """Yield the items of *iterable* repeatedly, printing 'end' before each pass.

    The iterable is re-iterated on every pass (rather than cached as
    ``itertools.cycle`` does), so re-iterable sources such as shuffling data
    loaders produce a fresh ordering each time.

    The generator stops when a whole pass yields nothing. This covers empty
    input, and one-shot iterators that are exhausted after their first pass.
    Without this check the loop would print 'end' forever without yielding.
    """
    while True:
        print('end')
        yielded_any = False
        for item in iterable:
            yielded_any = True
            yield item
        if not yielded_any:
            # Nothing came out of this pass; another pass cannot help.
            return
# Every entry inside the input directory (this list is not used elsewhere in
# the visible script).
_image_glob = os.path.join(opt.Image_dir, '*')
Image_path = glob(_image_glob)

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(opt.model_type)

# The only preprocessing step: ToTensor turns the PIL image into a
# channel-first tensor.
transform = transforms.Compose([transforms.ToTensor()])
# Build the requested network; 'Res_Unet' is the only type wired up here.
if opt.model_type == 'Res_Unet':
    model = CARes_Unet()
else:
    print('Model Not Found')
    # SystemExit works without the interactive-only ``exit`` helper from site.
    raise SystemExit(-1)

# Load the supervised weights. map_location remaps saved tensors onto the
# current device, so a checkpoint written on a GPU also loads on a CPU-only
# machine.
# NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
if os.path.exists(opt.model_path):
    state_dict = torch.load(opt.model_path, map_location=device)
    model.load_state_dict(state_dict)
else:
    print('Model Not Found!')
    raise SystemExit(-1)

model.eval()
model.to(device)
# Make sure the mask output directory exists. makedirs also creates missing
# parent directories, and exist_ok makes it a no-op when the directory is
# already there. (The unused test_loss/test_dice/i counters were removed.)
os.makedirs(opt.Mask_dir, exist_ok=True)
# Predict a mask for every file in Image_dir and save it to Mask_dir under
# the same file name. no_grad: this is inference only, so skip building the
# autograd graph for each image.
with torch.no_grad():
    for name in os.listdir(opt.Image_dir):
        src_path = os.path.join(opt.Image_dir, name)
        if not os.path.isfile(src_path):
            continue  # sub-directories etc. cannot be opened as images
        # Close the source file promptly; convert('L') returns a new,
        # fully-loaded grayscale image.
        with Image.open(src_path) as src:
            img = src.convert('L')
        # LANCZOS is the filter the old ANTIALIAS alias referred to;
        # ANTIALIAS itself was removed in Pillow 10.
        img = img.resize((256, 256), Image.LANCZOS)
        # Add a leading batch dimension of size 1.
        img = transform(img).unsqueeze(0).to(device)
        output = model(img)
        # Per-pixel argmax over dim 1 -- assumes the model returns
        # (batch, classes, H, W); TODO confirm against model.py.
        pred = torch.argmax(output, dim=1, keepdim=True).float()
        # NOTE(review): *255 maps class 1 to white and assumes a binary
        # prediction; class indices >= 2 overflow uint8.
        out = (pred[0, 0].cpu().numpy() * 255).astype(np.uint8)
        # NOTE(review): the output format follows the input file extension;
        # a .jpg input yields a lossy JPEG mask.
        Image.fromarray(out).save(os.path.join(opt.Mask_dir, name))