8
8
from torch .utils .data import Dataset
9
9
from geotnf .transformation import GeometricTnf
10
10
11
+
11
12
class CaltechDataset (Dataset ):
12
-
13
13
"""
14
14
15
15
Caltech-101 image pair dataset
@@ -23,66 +23,88 @@ class CaltechDataset(Dataset):
23
23
24
24
"""
25
25
26
def __init__(self, csv_file, dataset_path, output_size=(240, 240), transform=None):
    """
    Load the image-pair table and set up the image resizing transformation.

    Args:
        csv_file (string): Path to the csv file read with pandas. Columns 0
            and 1 hold the image names of each pair, column 2 the category
            (cast to float), columns 3-4 the annotation of image A and the
            remaining columns the annotation of image B.
        dataset_path (string): Directory the image names are joined onto.
        output_size (2-tuple): (height, width) handed to GeometricTnf.
        transform (callable, optional): Applied to every sample dict
            returned by __getitem__.
    """
    self.category_names = [
        'Faces', 'Faces_easy', 'Leopards', 'Motorbikes', 'accordion', 'airplanes', 'anchor',
        'ant', 'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus',
        'buddha', 'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone',
        'chair', 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', 'crocodile',
        'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly',
        'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'ferry', 'flamingo',
        'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill',
        'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree',
        'kangaroo', 'ketch', 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin',
        'mayfly', 'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi', 'pagoda',
        'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster',
        'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse', 'snoopy', 'soccer_ball',
        'stapler', 'starfish', 'stegosaurus', 'stop_sign', 'strawberry', 'sunflower', 'tick',
        'trilobite', 'umbrella', 'watch', 'water_lilly', 'wheelchair', 'wild_cat',
        'windsor_chair', 'wrench', 'yin_yang'
    ]
    self.out_h, self.out_w = output_size
    self.pairs = pd.read_csv(csv_file)
    self.img_A_names = self.pairs.iloc[:, 0]
    self.img_B_names = self.pairs.iloc[:, 1]
    # Series.as_matrix() was removed in pandas 1.0; .values yields the same
    # ndarray and is available on both old and current pandas releases.
    self.category = self.pairs.iloc[:, 2].values.astype('float')
    self.annot_A_str = self.pairs.iloc[:, 3:5]
    self.annot_B_str = self.pairs.iloc[:, 5:]
    self.dataset_path = dataset_path
    self.transform = transform
    # no cuda as dataset is called from CPU threads in dataloader and produces conflict
    self.affineTnf = GeometricTnf(out_h=self.out_h, out_w=self.out_w, use_cuda=False)
56
+
41
57
def __len__ (self ):
42
58
return len (self .pairs )
43
59
44
60
def __getitem__(self, idx):
    """
    Build the sample dict for the image pair at row idx.

    The dict holds both pre-processed images, their sizes as returned by
    get_image and both annotations as returned by get_points. When a
    transform was supplied, its result is returned instead of the raw dict.
    """
    # pre-processed images (source first, then target)
    src_image, src_size = self.get_image(self.img_A_names, idx)
    tgt_image, tgt_size = self.get_image(self.img_B_names, idx)

    # pre-processed point coords
    src_points = self.get_points(self.annot_A_str, idx)
    tgt_points = self.get_points(self.annot_B_str, idx)

    sample = dict(
        source_image=src_image,
        target_image=tgt_image,
        source_im_size=src_size,
        target_im_size=tgt_size,
        source_polygon=src_points,
        target_polygon=tgt_points,
    )

    if not self.transform:
        return sample
    return self.transform(sample)
59
82
60
def get_image(self, img_name_list, idx):
    """
    Read the image named at img_name_list[idx] and return it resized.

    Returns:
        tuple: (image, im_size) where image is the output of self.affineTnf
        with the leading batch axis squeezed away, and im_size is a float
        tensor holding the array shape of the image as read (after the
        grayscale-to-3-channel expansion, before resizing).
    """
    path = os.path.join(self.dataset_path, img_name_list[idx])
    pixels = io.imread(path)

    # a 2-D array means grayscale: replicate it along a new last axis
    if pixels.ndim == 2:
        pixels = np.repeat(np.expand_dims(pixels, 2), axis=2, repeats=3)

    # record the size before any resizing takes place
    im_size = torch.Tensor(np.asarray(pixels.shape).astype(np.float32))

    # move channels first, add a batch axis and wrap as a torch Variable
    channels_first = pixels.transpose((2, 0, 1))
    batch = np.expand_dims(channels_first, 0)
    image_var = Variable(torch.Tensor(batch.astype(np.float32)), requires_grad=False)

    # Resize image using bilinear sampling with identity affine tnf
    image = self.affineTnf(image_var).data.squeeze(0)

    return (image, im_size)
82
-
83
def get_points(self, point_coords_list, idx):
    """
    Return the annotation entries of row idx as an (x, y) tuple.

    x is taken from the first column of point_coords_list and y from the
    second; the values are returned exactly as stored in the table.
    """
    x_column = point_coords_list.columns[0]
    coords_x = point_coords_list[x_column][idx]

    y_column = point_coords_list.columns[1]
    coords_y = point_coords_list[y_column][idx]

    return (coords_x, coords_y)
0 commit comments