-
Notifications
You must be signed in to change notification settings - Fork 25
/
pre-process.py
89 lines (73 loc) · 3.27 KB
/
pre-process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import json
import os
import zipfile
import cv2 as cv
from tqdm import tqdm
from config import img_height, img_width
def ensure_folder(folder):
    """Create the directory ``folder``, including missing parents, unless the path already exists."""
    if os.path.exists(folder):
        # Something is already at this path; nothing to create.
        return
    os.makedirs(folder)
def extract(usage, package, image_path, json_path):
    """Unpack a labelled archive and write resized images into per-label folders.

    Extracts ``data/<package>.zip`` into ``data``, loads the annotation list
    from ``data/<package>/<json_path>`` and, for every entry, reads the image
    named by ``image_id`` from ``data/<package>/<image_path>``, resizes it to
    ``img_width`` x ``img_height`` (from ``config``) and saves it as
    ``data/<usage>/<label>/<image_id>``, where ``<label>`` is the entry's
    ``label_id`` formatted as a zero-padded two-digit number.

    Args:
        usage: Name of the output sub-folder under ``data`` (e.g. ``'train'``).
        package: Base name of the zip archive and of the folder it extracts to.
        image_path: Folder inside the extracted package that holds the images.
        json_path: Annotation JSON file inside the extracted package.

    Raises:
        IOError: If an image listed in the annotations cannot be read.
    """
    filename = 'data/{}.zip'.format(package)
    print('Extracting {}...'.format(filename))
    with zipfile.ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall('data')

    dst_folder = 'data/{}'.format(usage)
    ensure_folder(dst_folder)

    with open('data/{}/{}'.format(package, json_path)) as json_data:
        data = json.load(json_data)

    print("num_samples: " + str(len(data)))

    # The source folder does not depend on the sample, so build it once.
    src_folder = 'data/{}/{}'.format(package, image_path)

    for item in tqdm(data):
        image_name = item['image_id']
        label = "%02d" % (int(item['label_id']),)

        # One sub-folder per label, created the first time the label is seen.
        label_folder = os.path.join(dst_folder, label)
        ensure_folder(label_folder)

        src_path = os.path.join(src_folder, image_name)
        src_image = cv.imread(src_path)
        if src_image is None:
            # cv.imread signals a missing/unreadable file by returning None;
            # fail here with the offending path instead of inside cv.resize.
            raise IOError('Unable to read image: {}'.format(src_path))

        # dsize is (width, height). The interpolation flag must be passed by
        # keyword because the third positional parameter of cv.resize is dst.
        dst_image = cv.resize(src_image, (img_width, img_height),
                              interpolation=cv.INTER_CUBIC)
        cv.imwrite(os.path.join(label_folder, image_name), dst_image)
def extract_test(usage, package, image_path, json_path):
    """Unpack a test archive, write resized images flat, and dump their labels.

    Extracts ``data/<package>.zip`` into ``data``, loads the annotation list
    from ``data/<package>/<json_path>`` and, for every entry, reads the image
    named by ``image_id`` from ``data/<package>/<image_path>``, resizes it to
    ``img_width`` x ``img_height`` (from ``config``) and saves it directly as
    ``data/<usage>/<image_id>`` (no per-label sub-folders). The mapping
    ``image_id -> zero-padded two-digit label`` is written as JSON to
    ``label_dict.txt`` in the current working directory.

    NOTE: the output file name is fixed and opened in write mode, so each call
    replaces the ``label_dict.txt`` produced by a previous call.

    Args:
        usage: Name of the output sub-folder under ``data`` (e.g. ``'test_a'``).
        package: Base name of the zip archive and of the folder it extracts to.
        image_path: Folder inside the extracted package that holds the images.
        json_path: Annotation JSON file inside the extracted package.

    Raises:
        IOError: If an image listed in the annotations cannot be read.
    """
    filename = 'data/{}.zip'.format(package)
    print('Extracting {}...'.format(filename))
    with zipfile.ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall('data')

    dst_folder = 'data/{}'.format(usage)
    ensure_folder(dst_folder)

    with open('data/{}/{}'.format(package, json_path)) as json_data:
        data = json.load(json_data)

    print("num_samples: " + str(len(data)))

    # The source folder does not depend on the sample, so build it once.
    src_folder = 'data/{}/{}'.format(package, image_path)

    label_dict = dict()
    for item in tqdm(data):
        image_name = item['image_id']
        label_dict[image_name] = "%02d" % (int(item['label_id']),)

        src_path = os.path.join(src_folder, image_name)
        src_image = cv.imread(src_path)
        if src_image is None:
            # cv.imread signals a missing/unreadable file by returning None;
            # fail here with the offending path instead of inside cv.resize.
            raise IOError('Unable to read image: {}'.format(src_path))

        # dsize is (width, height). The interpolation flag must be passed by
        # keyword because the third positional parameter of cv.resize is dst.
        dst_image = cv.resize(src_image, (img_width, img_height),
                              interpolation=cv.INTER_CUBIC)
        cv.imwrite(os.path.join(dst_folder, image_name), dst_image)

    with open('label_dict.txt', 'w') as outfile:
        json.dump(label_dict, outfile, indent=4, sort_keys=True)
if __name__ == '__main__':
    # Every archive is read from, and every output written under, 'data'.
    ensure_folder('data')

    # Splits handled by extract(): (usage, package, image folder, annotations).
    labelled_splits = (
        ('train', 'ai_challenger_scene_train_20170904', 'scene_train_images_20170904',
         'scene_train_annotations_20170904.json'),
        ('valid', 'ai_challenger_scene_validation_20170908', 'scene_validation_images_20170908',
         'scene_validation_annotations_20170908.json'),
    )
    for split_args in labelled_splits:
        extract(*split_args)

    # Splits handled by extract_test(), in the same tuple layout.
    test_splits = (
        ('test_a', 'ai_challenger_scene_test_a_20180103', 'scene_test_a_images_20180103',
         'scene_test_a_annotations_20180103.json'),
        ('test_b', 'ai_challenger_scene_test_b_20180103', 'scene_test_b_images_20180103',
         'scene_test_b_annotations_20180103.json'),
    )
    for split_args in test_splits:
        extract_test(*split_args)