Skip to content

Commit 79aed12

Browse files
author
michuanhaohao
committed
add keras_reid
1 parent 1fbd0f5 commit 79aed12

File tree

3 files changed

+505
-0
lines changed

3 files changed

+505
-0
lines changed

aug.py

+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Created on Thu Jun 8 10:37:16 2017
5+
6+
@author: luohao
7+
"""
8+
9+
import cv2
10+
import numpy as np
11+
import nori2 as nori
12+
from tqdm import tqdm
13+
14+
def augment(img, rng, img_shape,do_training):
15+
# img = imgproc.resize_preserve_aspect_ratio(img, image_shape)
16+
17+
if do_training:
18+
# data augmentation from fb.resnet.torch
19+
# https://github.com/facebook/fb.resnet.torch/blob/master/datasets/imagenet.lua
20+
def scale(img, size):
21+
s = size / min(img.shape[0], img.shape[1])
22+
h, w = int(round(img.shape[0] * s)), int(round(img.shape[1] * s))
23+
return cv2.resize(img, (w, h))
24+
def center_crop(img, shape):
25+
h, w = img.shape[:2]
26+
sx, sy = (w - shape[1]) // 2, (h - shape[0]) // 2
27+
img = img[sy:sy + shape[0], sx:sx + shape[1]]
28+
return img
29+
def random_sized_crop(img):
30+
NR_REPEAT = 10
31+
h, w = img.shape[:2]
32+
area = h * w
33+
ar = [3. / 4, 4. / 3]
34+
for i in range(NR_REPEAT):
35+
target_area = rng.uniform(0.08, 1.0) * area
36+
target_ar = rng.choice(ar)
37+
nw = int(round((target_area * target_ar) ** 0.5))
38+
nh = int(round((target_area / target_ar) ** 0.5))
39+
if rng.rand() < 0.5:
40+
nh, nw = nw, nh
41+
if nh <= h and nw <= w:
42+
sx, sy = rng.randint(w - nw + 1), rng.randint(h - nh + 1)
43+
img = img[sy:sy + nh, sx:sx + nw]
44+
return cv2.resize(img, image_shape[::-1])
45+
size = min(image_shape[0], image_shape[1])
46+
return center_crop(scale(img, size), image_shape)
47+
def grayscale(img):
48+
w = np.array([0.114, 0.587, 0.299]).reshape(1, 1, 3)
49+
gs = np.zeros(img.shape[:2])
50+
gs = (img * w).sum(axis=2, keepdims=True)
51+
return gs
52+
def brightness_aug(img, val):
53+
alpha = 1. + val * (rng.rand() * 2 - 1)
54+
img = img * alpha
55+
return img
56+
def contrast_aug(img, val):
57+
gs = grayscale(img)
58+
gs[:] = gs.mean()
59+
alpha = 1. + val * (rng.rand() * 2 - 1)
60+
img = img * alpha + gs * (1 - alpha)
61+
return img
62+
def saturation_aug(img, val):
63+
gs = grayscale(img)
64+
alpha = 1. + val * (rng.rand() * 2 - 1)
65+
img = img * alpha + gs * (1 - alpha)
66+
return img
67+
def color_jitter(img, brightness, contrast, saturation):
68+
augs = [(brightness_aug, brightness),
69+
(contrast_aug, contrast),
70+
(saturation_aug, saturation)]
71+
rng.shuffle(augs)
72+
for aug, val in augs:
73+
img = aug(img, val)
74+
return img
75+
def lighting(img, std):
76+
eigval = np.array([0.2175, 0.0188, 0.0045])
77+
eigvec = np.array([
78+
[-0.5836, -0.6948, 0.4203],
79+
[-0.5808, -0.0045, -0.8140],
80+
[-0.5675, 0.7192, 0.4009],
81+
])
82+
if std == 0:
83+
return img
84+
alpha = rng.randn(3) * std
85+
bgr = eigvec * alpha.reshape(1, 3) * eigval.reshape(1, 3)
86+
bgr = bgr.sum(axis=1).reshape(1, 1, 3)
87+
img = img + bgr
88+
return img
89+
def horizontal_flip(img, prob):
90+
if rng.rand() < prob:
91+
return img[:, ::-1]
92+
return img
93+
# def warp_perspective(img):
94+
# c = (
95+
# ((-50, 50), (-10, 10)),
96+
# ((-50, 50), (-10, 10)),
97+
# ((-50, 50), (-10, 10)),
98+
# ((-50, 50), (-10, 10))
99+
# )
100+
# mat = imgaug.get_random_perspective_transform_mat(
101+
# rng, c, image_shape)
102+
# return cv2.warpPerspective(img, mat, image_shape)
103+
img = color_jitter(img, brightness=0.4, contrast=0.4, saturation=0.4)
104+
img = lighting(img, 0.1)
105+
img = horizontal_flip(img, 0.5)
106+
img = np.minimum(255.0, np.maximum(0, img))
107+
# return np.rollaxis(img, 2).astype('float32')
108+
return img.astype('float32') #return h*w*3
109+
110+
def aug_nhw3(imgs):
111+
for ind in range(len(imgs)):
112+
img = imgs[ind]
113+
rng = np.random.RandomState()
114+
do_training = True
115+
imgs[ind] = augment(img, rng, (128,128),do_training)
116+
return imgs
117+
118+
def aug_n3hw(imgs):
119+
imgs = np.transpose(imgs,(0,2,3,1))
120+
imgs = aug_nhw3(imgs)
121+
imgs = np.transpose(imgs,(0,3,1,2))
122+
return imgs
123+
124+
125+
126+

reid_classification.py

+182
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Created on Fri Jun 2 16:21:54 2017
5+
6+
@author: luohao
7+
"""
8+
9+
import numpy as np
10+
11+
12+
from keras import optimizers
13+
from keras.utils import np_utils, generic_utils
14+
from keras.models import Sequential,Model
15+
from keras.layers import Dropout, Flatten, Dense,Input
16+
from keras.applications.resnet50 import ResNet50
17+
from keras.applications.imagenet_utils import preprocess_input
18+
from keras import backend as K
19+
from keras.layers.core import Lambda
20+
from sklearn.preprocessing import normalize
21+
from keras.preprocessing.image import ImageDataGenerator
22+
from keras.initializers import RandomNormal
23+
24+
25+
import numpy.linalg as la
26+
from IPython import embed
27+
28+
#欧式距离
29+
def euclidSimilar(query_ind,test_all,top_num):
30+
le = len(test_all)
31+
dis = np.zeros(le)
32+
for ind in range(le):
33+
sub = test_all[ind]-query_ind
34+
dis[ind] = la.norm(sub)
35+
ii = sorted(range(len(dis)), key=lambda k: dis[k])
36+
# print(ii[:top_num+1])
37+
return ii[1:top_num+1]
38+
39+
def euclidSimilar2(query_ind,test_all):
40+
le = len(test_all)
41+
dis = np.zeros(le)
42+
for ind in range(le):
43+
sub = test_all[ind]-query_ind
44+
dis[ind] = la.norm(sub)
45+
ii = sorted(range(len(dis)), key=lambda k: dis[k])
46+
# embed()
47+
# print(ii[:top_num+1])
48+
return ii
49+
50+
def get_top_ind(query_all,test_all,top_num):
51+
query_num = len(query_all)
52+
query_result_ind = np.zeros([query_num,top_num],np.int32)
53+
for ind in range(query_num):
54+
query_result_ind[ind]=euclidSimilar(query_all[ind],test_all,top_num)
55+
# if np.mod(ind,100)==0:
56+
# print('query_ind '+str(ind))
57+
return query_result_ind
58+
59+
def get_top_label(query_result_ind):
60+
num = len(query_result_ind)
61+
top_num = len(query_result_ind[0])
62+
query_top_label = np.zeros([num,top_num],np.int32)
63+
for ind in range(num):
64+
for ind2 in range(top_num):
65+
query_top_label[ind][ind2]= test_label[query_result_ind[ind][ind2]]
66+
# if np.mod(ind,100)==0:
67+
# print('query_label '+str(ind))
68+
# print(query_top_label[ind])
69+
return query_top_label
70+
71+
def get_top_acc(test_label,query_result_label):
72+
query_label= test_label
73+
top1 = 0
74+
top5 = 0
75+
top10 = 0
76+
for ind in range(len(query_result_label)):
77+
query_temp_label = query_result_label[ind]-query_label[ind]
78+
# print(query_temp_label)
79+
query_temp = np.where(query_temp_label==0)
80+
if len(query_temp[0] >0):
81+
if query_temp[0][0]<1:
82+
top1 = top1+1
83+
if query_temp[0][0]<5:
84+
top5 = top5+1
85+
if query_temp[0][0]<10:
86+
top10 = top10+1
87+
ind = ind +1
88+
top1 = top1/ind*1.0
89+
top5 = top5/ind*1.0
90+
top10 = top10/ind*1.0
91+
print(str(ind)+' query images')
92+
return top1,top5,top10
93+
94+
95+
def single_query(query_feature,test_feature,query_label,test_label,test_num):
96+
test_label_set = np.unique(test_label)
97+
#single_num = len(test_label_set)
98+
test_label_dict={}
99+
topp1=0
100+
topp5=0
101+
topp10=0
102+
for ind in range(len(test_label_set)):
103+
test_label_dict[test_label_set[ind]]=np.where(test_label==test_label_set[ind])
104+
for ind in range(test_num):
105+
query_int = np.random.choice(len(query_label))
106+
label = query_label[query_int]
107+
temp_int = np.random.choice(test_label_dict[label][0],1)
108+
temp_gallery_ind = temp_int
109+
for ind2 in range(len(test_label_set)):
110+
temp_label = test_label_set[ind2]
111+
if temp_label != label:
112+
temp_int = np.random.choice(test_label_dict[temp_label][0],1)
113+
temp_gallery_ind = np.append(temp_gallery_ind,temp_int)
114+
single_query_feature = query_feature[query_int]
115+
test_all_feature = test_feature[temp_gallery_ind]
116+
result_ind = euclidSimilar2(single_query_feature,test_all_feature)
117+
query_temp = result_ind.index(0)
118+
if query_temp<1:
119+
topp1 = topp1+1
120+
if query_temp<5:
121+
topp5 = topp5+1
122+
if query_temp<10:
123+
topp10 = topp10+1
124+
topp1 =topp1/test_num*1.0
125+
topp5 =topp5/test_num*1.0
126+
topp10 =topp10/test_num*1.0
127+
print('single query')
128+
print('top1: '+str(topp1)+'\n')
129+
print('top5: '+str(topp5)+'\n')
130+
print('top10: '+str(topp10)+'\n')
131+
132+
133+
"================================"
134+
135+
identity_num = 6273
136+
137+
print('loading data...')
138+
## please add your loading validation data here
139+
140+
141+
#from load_market_img import get_img
142+
#query_img,test_img,query_label,test_label=get_img()
143+
test_img =preprocess_input(test_img)
144+
query_img = preprocess_input(query_img)
145+
146+
''''''''''''''''''''''''''
147+
datagen = ImageDataGenerator(horizontal_flip=True)
148+
149+
# load pre-trained resnet50
150+
base_model = ResNet50(weights='imagenet', include_top=False, input_tensor=Input(shape=(224,224,3)))
151+
x = base_model.output
152+
feature = Flatten(name='flatten')(x)
153+
fc1 = Dropout(0.5)(feature)
154+
preds = Dense(identity_num, activation='softmax', name='fc8', kernel_initializer=RandomNormal(mean=0.0, stddev=0.001))(fc1) #default glorot_uniform
155+
net = Model(input=base_model.input, output=preds)
156+
feature_model = Model(input=base_model.input, output=feature)
157+
158+
#training IDE model for all layers
159+
for layer in net.layers:
160+
layer.trainable = True
161+
162+
# train
163+
batch_size = 16
164+
#step 1
165+
adam = optimizers.Adam(lr=0.001)
166+
net.compile(optimizer=adam, loss='categorical_crossentropy',metric ='accuracy')
167+
168+
# your can add a pre-trained model here
169+
#net.load_weights('net_ide.h5')
170+
from load_img_data import get_train_img
171+
172+
ind = 0
173+
while(True):
174+
train_img,train_label = get_train_img(8000) #add your loading training data here
175+
train_img = preprocess_input(train_img)
176+
train_label_onehot = np_utils.to_categorical(train_label,identity_num)
177+
net.fit_generator(datagen.flow(train_img, train_label_onehot, batch_size=batch_size),
178+
steps_per_epoch=len(train_img)/batch_size, epochs=1)
179+
ind = ind+1
180+
# your can add sth here
181+
# if np.mod(ind,100) == 0:
182+

0 commit comments

Comments
 (0)