-
Notifications
You must be signed in to change notification settings - Fork 0
/
func.py
119 lines (93 loc) · 3.19 KB
/
func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
def createFolder(args):
    """Create the directory ``args`` (and any missing parents) if needed.

    ``exist_ok=True`` makes the existence check and the creation a single
    call, so there is no window in which another process can create the
    directory between a separate ``exists`` test and ``makedirs`` (which
    previously raised ``FileExistsError``).

    Args:
        args: path of the directory to create.

    Raises:
        FileExistsError: if ``args`` exists but is not a directory.
    """
    os.makedirs(args, exist_ok=True)
def _integral_image(values):
    """Return the zero-padded 2-D summed-area table of ``values``.

    ``table[r, c]`` equals ``values[:r, :c].sum()``, so any axis-aligned box
    sum can be read with four lookups. Accumulation is in float64, which is
    exact for integer-valued data as long as totals stay below 2**53.
    """
    values = np.asarray(values, dtype=np.float64)
    table = np.zeros((values.shape[0] + 1, values.shape[1] + 1), dtype=np.float64)
    table[1:, 1:] = values.cumsum(axis=0).cumsum(axis=1)
    return table


def _box_sums(table, row_lo, row_hi, col_lo, col_hi):
    """Return sums of every box ``[row_lo[i]:row_hi[i], col_lo[j]:col_hi[j]]``.

    ``table`` comes from ``_integral_image``. The four bounds are 1-D index
    arrays, and the result has shape ``(len(row_lo), len(col_lo))``.
    """
    r0, r1 = row_lo[:, None], row_hi[:, None]
    c0, c1 = col_lo[None, :], col_hi[None, :]
    return table[r1, c1] - table[r0, c1] - table[r1, c0] + table[r0, c0]


def fill_color(img, canvas, *, window=5, fill_size=6, mean_thresh=30, fill_value=255):
    """Grow bright areas of ``canvas`` into solid ``fill_value`` blocks.

    Consider every top-left position ``(r, c)`` with ``r < img.shape[0]`` and
    ``c < img.shape[1]``:

    1. The mean of ``canvas[r:r+window, c:c+window]`` is taken over all
       channels. Windows are clipped at the canvas border, and positions
       outside the canvas are never selected.
    2. All positions are evaluated on the unmodified canvas first.
    3. ``canvas[r:r+fill_size, c:c+fill_size]`` is then set to
       ``fill_value`` for every position whose mean is strictly greater than
       ``mean_thresh``.

    This produces the same result as a per-pixel Python double loop. Both the
    window means and the painting are computed with summed-area tables, so the
    work is a handful of whole-array numpy operations rather than one Python
    slice/mean per pixel.

    Args:
        img: only ``img.shape[0]`` and ``img.shape[1]`` are used; they bound
            the scanned positions.
        canvas: 2-D or 3-D array (``main`` passes the (H, W, 3) uint8
            threshold mask). It is modified in place.
        window: side length of the averaging window (must be >= 1).
        fill_size: side length of the painted block (must be >= 1).
        mean_thresh: a window is selected when its mean is strictly greater.
        fill_value: value written into painted blocks (all channels).

    Returns:
        ``canvas`` itself, modified in place.

    Raises:
        ValueError: if ``window`` or ``fill_size`` is smaller than 1.
    """
    if window < 1 or fill_size < 1:
        raise ValueError("window and fill_size must be >= 1")
    height, width = canvas.shape[:2]
    rows = min(img.shape[0], height)
    cols = min(img.shape[1], width)
    if rows == 0 or cols == 0 or canvas.size == 0:
        return canvas
    per_pixel = canvas.size // (height * width)  # channels per pixel; 1 for 2-D

    # Step 1: mean of every (clipped) window, read from a summed-area table.
    pixel_sums = canvas.reshape(height, width, per_pixel).sum(axis=2, dtype=np.float64)
    sums_table = _integral_image(pixel_sums)
    top = np.arange(rows)
    left = np.arange(cols)
    bottom = np.minimum(top + window, height)
    right = np.minimum(left + window, width)
    window_sums = _box_sums(sums_table, top, bottom, left, right)
    window_sizes = np.outer(bottom - top, right - left) * per_pixel
    selected = window_sums / window_sizes > mean_thresh

    # Step 2: pixel (y, x) is painted iff some selected (r, c) satisfies
    # r <= y < r + fill_size and c <= x < c + fill_size, i.e. the number of
    # selected seeds in rows [y-fill_size+1, y] x cols [x-fill_size+1, x]
    # (clipped at 0) is non-zero.
    seeds = np.zeros((height, width), dtype=np.float64)
    seeds[:rows, :cols] = selected
    seeds_table = _integral_image(seeds)
    y_hi = np.arange(1, height + 1)
    x_hi = np.arange(1, width + 1)
    covered = _box_sums(
        seeds_table,
        np.maximum(y_hi - fill_size, 0), y_hi,
        np.maximum(x_hi - fill_size, 0), x_hi,
    ) > 0
    canvas[covered] = fill_value  # boolean index on (H, W) fills all channels
    return canvas
def img_segmentation(img):
    """Keep four fixed corner regions of ``img`` and zero everything else.

    The corner coordinates are hard-coded for the 256x256 images produced by
    ``main``, which resizes every input. On other sizes the same absolute
    slices are applied, so they cover different relative areas.

    Args:
        img: image array, e.g. the (256, 256, 3) uint8 frame from ``main``.

    Returns:
        A new uint8 array of ``img.shape``. It holds ``img``'s values inside
        the corner regions and 0 elsewhere.
    """
    corners = (
        np.s_[:51, :51],      # top-left
        np.s_[202:, :76],     # bottom-left
        np.s_[:55, 200:],     # top-right
        np.s_[190:, 175:],    # bottom-right
    )
    mask = np.zeros(shape=img.shape, dtype=np.uint8)
    for region in corners:
        # Direct copy; numpy casts to the mask's uint8 dtype on assignment.
        mask[region] = img[region]
    return mask
def fig_show(img, segmented_img, thresh, filled_mask, masked_img, *, bgr=True):
    """Plot the five pipeline stages side by side in one row.

    Args:
        img: original image (``main`` passes the array from ``cv2.imread``).
        segmented_img: output of the segmentation step.
        thresh: threshold mask before filling.
        filled_mask: threshold mask after ``fill_color``.
        masked_img: original image with the filled mask painted white.
        bgr: when True (default), 3-channel inputs are flipped from OpenCV's
            BGR channel order to the RGB order ``plt.imshow`` expects, so the
            colours are shown correctly. Pass False for arrays already in RGB.

    Ends with ``plt.show()``.
    """
    panels = (
        (img, 'original img'),
        (segmented_img, 'segmented img'),
        (thresh, 'thresh_img'),
        (filled_mask, 'filled_mask'),
        (masked_img, 'masked_img'),
    )
    plt.figure(figsize=(8, 8))
    for position, (panel, title) in enumerate(panels, start=1):
        if bgr and np.ndim(panel) == 3 and np.shape(panel)[2] == 3:
            # OpenCV stores colour images as BGR; matplotlib expects RGB.
            panel = np.asarray(panel)[..., ::-1]
        plt.subplot(1, len(panels), position)
        plt.imshow(panel)
        plt.title(title)
        plt.xticks([])
        plt.yticks([])
    plt.show()
def main(args):
    """Threshold every image in ``args.baseroot`` and optionally show or save it.

    For each file that OpenCV can decode, the steps are:

    1. Resize the image to 256x256.
    2. Optionally keep only the fixed corner regions (``args.segmentation``).
    3. Mark pixels whose three channels all exceed ``args.threshold``.
    4. Optionally grow that mask with ``fill_color`` (``args.img_fill``).
    5. Paint the masked pixels white on a copy of the image.

    Args:
        args: namespace providing ``baseroot`` (input directory),
            ``save_dir`` (output directory), ``segmentation``, ``threshold``,
            ``img_fill``, ``img_show`` and ``save_fig``.

    Entries that ``cv2.imread`` cannot read (it returns None) are skipped
    instead of crashing in ``cv2.resize``. Files are processed in sorted name
    order. When ``args.save_fig`` is set, the filled mask (not the masked
    image) is written to ``args.save_dir`` under the same file name.
    """
    img_path = args.baseroot
    save_path = args.save_dir
    if args.save_fig:
        # Create the output folder once, up front, rather than per image.
        createFolder(save_path)
    for file_name in sorted(os.listdir(img_path)):
        print(file_name)
        # os.path.join works whether or not the directory has a trailing slash.
        img = cv2.imread(os.path.join(img_path, file_name))
        if img is None:
            print('skipped (cannot read image):', file_name)
            continue
        img = cv2.resize(img, (256, 256))
        print('img.shape :', img.shape)
        img_seg = img_segmentation(img) if args.segmentation else img
        mask = np.zeros(shape=img.shape, dtype=np.uint8)
        mask[(img_seg > args.threshold).all(axis=2)] = 255
        thresh = mask.copy()  # snapshot: fill_color modifies ``mask`` in place
        fill_mask = fill_color(img, mask) if args.img_fill else mask
        masked = img.copy()
        masked[(fill_mask > 230).all(axis=2)] = 255
        if args.img_show:
            fig_show(img, img_seg, thresh, fill_mask, masked)
        if args.save_fig:
            cv2.imwrite(os.path.join(save_path, file_name), fill_mask)