Skip to content

Commit 4f4ee28

Browse files
committed
Fix the style
1 parent 8a90d0e commit 4f4ee28

26 files changed

+2655
-1787
lines changed

.style.yapf

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[style]
2+
based_on_style = google
3+
column_limit = 100
4+
dedent_closing_brackets = True

Pipfile

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ cython = "*"
1717

1818
[dev-packages]
1919
ipython = "*"
20+
yapf = "*"
2021

2122
[requires]
2223
python_version = "3.5"

Pipfile.lock

+9-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

data/caltech_dataset.py

+52-30
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from torch.utils.data import Dataset
99
from geotnf.transformation import GeometricTnf
1010

11+
1112
class CaltechDataset(Dataset):
12-
1313
"""
1414
1515
Caltech-101 image pair dataset
@@ -23,66 +23,88 @@ class CaltechDataset(Dataset):
2323
2424
"""
2525

26-
def __init__(self, csv_file, dataset_path,output_size=(240,240),transform=None):
26+
def __init__(self, csv_file, dataset_path, output_size=(240, 240), transform=None):
2727

28-
self.category_names=['Faces','Faces_easy','Leopards','Motorbikes','accordion','airplanes','anchor','ant','barrel','bass','beaver','binocular','bonsai','brain','brontosaurus','buddha','butterfly','camera','cannon','car_side','ceiling_fan','cellphone','chair','chandelier','cougar_body','cougar_face','crab','crayfish','crocodile','crocodile_head','cup','dalmatian','dollar_bill','dolphin','dragonfly','electric_guitar','elephant','emu','euphonium','ewer','ferry','flamingo','flamingo_head','garfield','gerenuk','gramophone','grand_piano','hawksbill','headphone','hedgehog','helicopter','ibis','inline_skate','joshua_tree','kangaroo','ketch','lamp','laptop','llama','lobster','lotus','mandolin','mayfly','menorah','metronome','minaret','nautilus','octopus','okapi','pagoda','panda','pigeon','pizza','platypus','pyramid','revolver','rhino','rooster','saxophone','schooner','scissors','scorpion','sea_horse','snoopy','soccer_ball','stapler','starfish','stegosaurus','stop_sign','strawberry','sunflower','tick','trilobite','umbrella','watch','water_lilly','wheelchair','wild_cat','windsor_chair','wrench','yin_yang']
28+
self.category_names = [
29+
'Faces', 'Faces_easy', 'Leopards', 'Motorbikes', 'accordion', 'airplanes', 'anchor',
30+
'ant', 'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus',
31+
'buddha', 'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone',
32+
'chair', 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', 'crocodile',
33+
'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly',
34+
'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'ferry', 'flamingo',
35+
'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill',
36+
'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree',
37+
'kangaroo', 'ketch', 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin',
38+
'mayfly', 'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi', 'pagoda',
39+
'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster',
40+
'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse', 'snoopy', 'soccer_ball',
41+
'stapler', 'starfish', 'stegosaurus', 'stop_sign', 'strawberry', 'sunflower', 'tick',
42+
'trilobite', 'umbrella', 'watch', 'water_lilly', 'wheelchair', 'wild_cat',
43+
'windsor_chair', 'wrench', 'yin_yang'
44+
]
2945
self.out_h, self.out_w = output_size
3046
self.pairs = pd.read_csv(csv_file)
31-
self.img_A_names = self.pairs.iloc[:,0]
32-
self.img_B_names = self.pairs.iloc[:,1]
33-
self.category = self.pairs.iloc[:,2].as_matrix().astype('float')
47+
self.img_A_names = self.pairs.iloc[:, 0]
48+
self.img_B_names = self.pairs.iloc[:, 1]
49+
self.category = self.pairs.iloc[:, 2].as_matrix().astype('float')
3450
self.annot_A_str = self.pairs.iloc[:, 3:5]
3551
self.annot_B_str = self.pairs.iloc[:, 5:]
36-
self.dataset_path = dataset_path
52+
self.dataset_path = dataset_path
3753
self.transform = transform
3854
# no cuda as dataset is called from CPU threads in dataloader and produces confilct
39-
self.affineTnf = GeometricTnf(out_h=self.out_h, out_w=self.out_w, use_cuda = False)
40-
55+
self.affineTnf = GeometricTnf(out_h=self.out_h, out_w=self.out_w, use_cuda=False)
56+
4157
def __len__(self):
4258
return len(self.pairs)
4359

4460
def __getitem__(self, idx):
4561
# get pre-processed images
46-
image_A,im_size_A = self.get_image(self.img_A_names,idx)
47-
image_B,im_size_B = self.get_image(self.img_B_names,idx)
62+
image_A, im_size_A = self.get_image(self.img_A_names, idx)
63+
image_B, im_size_B = self.get_image(self.img_B_names, idx)
4864

4965
# get pre-processed point coords
5066
annot_A = self.get_points(self.annot_A_str, idx)
5167
annot_B = self.get_points(self.annot_B_str, idx)
52-
53-
sample = {'source_image': image_A, 'target_image': image_B, 'source_im_size': im_size_A, 'target_im_size': im_size_B, 'source_polygon': annot_A, 'target_polygon': annot_B}
54-
68+
69+
sample = {
70+
'source_image': image_A,
71+
'target_image': image_B,
72+
'source_im_size': im_size_A,
73+
'target_im_size': im_size_B,
74+
'source_polygon': annot_A,
75+
'target_polygon': annot_B
76+
}
77+
5578
if self.transform:
5679
sample = self.transform(sample)
5780

5881
return sample
5982

60-
def get_image(self,img_name_list,idx):
83+
def get_image(self, img_name_list, idx):
6184
img_name = os.path.join(self.dataset_path, img_name_list[idx])
6285
image = io.imread(img_name)
63-
64-
# if grayscale convert to 3-channel image
65-
if image.ndim==2:
66-
image=np.repeat(np.expand_dims(image,2),axis=2,repeats=3)
67-
86+
87+
# if grayscale convert to 3-channel image
88+
if image.ndim == 2:
89+
image = np.repeat(np.expand_dims(image, 2), axis=2, repeats=3)
90+
6891
# get image size
6992
im_size = np.asarray(image.shape)
70-
93+
7194
# convert to torch Variable
72-
image = np.expand_dims(image.transpose((2,0,1)),0)
95+
image = np.expand_dims(image.transpose((2, 0, 1)), 0)
7396
image = torch.Tensor(image.astype(np.float32))
74-
image_var = Variable(image,requires_grad=False)
75-
97+
image_var = Variable(image, requires_grad=False)
98+
7699
# Resize image using bilinear sampling with identity affine tnf
77100
image = self.affineTnf(image_var).data.squeeze(0)
78-
101+
79102
im_size = torch.Tensor(im_size.astype(np.float32))
80-
103+
81104
return (image, im_size)
82-
83-
def get_points(self,point_coords_list,idx):
105+
106+
def get_points(self, point_coords_list, idx):
84107
point_coords_x = point_coords_list[point_coords_list.columns[0]][idx]
85108
point_coords_y = point_coords_list[point_coords_list.columns[1]][idx]
86109

87-
return (point_coords_x,point_coords_y)
88-
110+
return (point_coords_x, point_coords_y)

data/download_datasets.py

+71-62
Original file line numberDiff line numberDiff line change
@@ -4,37 +4,41 @@
44
import tarfile
55
import zipfile
66
import requests
7-
import sys
7+
import sys
88
import click
99

10-
def download_and_uncompress(url, dest=None, chunk_size=1024, replace="ask",
11-
label="Downloading {dest_basename} ({size:.2f}MB)"):
12-
dest = dest or "./"+url.split("/")[-1]
10+
11+
def download_and_uncompress(
12+
url,
13+
dest=None,
14+
chunk_size=1024,
15+
replace="ask",
16+
label="Downloading {dest_basename} ({size:.2f}MB)"
17+
):
18+
dest = dest or "./" + url.split("/")[-1]
1319
dest_dir = dirname(dest)
1420
if not exists(dest_dir):
15-
makedirs(dest_dir)
21+
makedirs(dest_dir)
1622
if exists(dest):
17-
if (replace is False
18-
or replace == "ask"
19-
and not click.confirm("Replace {}?".format(dest))):
23+
if (replace is False or replace == "ask" and not click.confirm("Replace {}?".format(dest))):
2024
return
2125
# download file
2226
with open(dest, "wb") as f:
23-
response = requests.get(url, stream=True)
24-
total_length = response.headers.get('content-length')
25-
26-
if total_length is None: # no content length header
27-
f.write(response.content)
28-
else:
29-
dl = 0
30-
total_length = int(total_length)
31-
for data in response.iter_content(chunk_size=4096):
32-
dl += len(data)
33-
f.write(data)
34-
done = int(50 * dl / total_length)
35-
sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (50-done)) )
36-
sys.stdout.write("{:.1%}".format(dl / total_length))
37-
sys.stdout.flush()
27+
response = requests.get(url, stream=True)
28+
total_length = response.headers.get('content-length')
29+
30+
if total_length is None: # no content length header
31+
f.write(response.content)
32+
else:
33+
dl = 0
34+
total_length = int(total_length)
35+
for data in response.iter_content(chunk_size=4096):
36+
dl += len(data)
37+
f.write(data)
38+
done = int(50 * dl / total_length)
39+
sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (50 - done)))
40+
sys.stdout.write("{:.1%}".format(dl / total_length))
41+
sys.stdout.flush()
3842
sys.stdout.write("\n")
3943
# uncompress
4044
if dest.endswith("zip"):
@@ -49,77 +53,82 @@ def download_and_uncompress(url, dest=None, chunk_size=1024, replace="ask",
4953
print("Extracting data...")
5054
file.extractall(dest_dir)
5155
file.close()
52-
56+
5357
return dest
5458

59+
5560
def download_PF_willow(dest="datasets/proposal-flow-willow"):
56-
print("Fetching PF Willow dataset ")
61+
print("Fetching PF Willow dataset ")
5762
url = "http://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset.zip"
5863
file_path = join(dest, basename(url))
59-
download_and_uncompress(url,file_path)
60-
61-
print('Downloading image pair list \n') ;
64+
download_and_uncompress(url, file_path)
65+
66+
print('Downloading image pair list \n')
6267
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_pf.csv"
63-
file_path = join(dest,basename(url))
64-
download_and_uncompress(url,file_path)
68+
file_path = join(dest, basename(url))
69+
download_and_uncompress(url, file_path)
70+
6571

6672
def download_PF_pascal(dest="datasets/proposal-flow-pascal"):
67-
print("Fetching PF Pascal dataset ")
73+
print("Fetching PF Pascal dataset ")
6874
url = "http://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset-PASCAL.zip"
6975
file_path = join(dest, basename(url))
70-
download_and_uncompress(url,file_path)
71-
72-
print('Downloading image pair list \n') ;
76+
download_and_uncompress(url, file_path)
77+
78+
print('Downloading image pair list \n')
7379
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_pf_pascal.csv"
74-
file_path = join(dest,basename(url))
75-
download_and_uncompress(url,file_path)
80+
file_path = join(dest, basename(url))
81+
download_and_uncompress(url, file_path)
7682
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/val_pairs_pf_pascal.csv"
77-
file_path = join(dest,basename(url))
78-
download_and_uncompress(url,file_path)
83+
file_path = join(dest, basename(url))
84+
download_and_uncompress(url, file_path)
85+
7986

8087
def download_pascal(dest="datasets/pascal-voc11"):
8188
print("Fetching Pascal VOC2011 dataset")
8289
url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar"
8390
file_path = join(dest, basename(url))
8491
download_and_uncompress(url, file_path)
85-
92+
93+
8694
def download_pascal_parts(dest="datasets/pascal-parts"):
8795
print("Fetching Pascal Parts dataset")
8896
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/pascal_data.tar"
8997
file_path = join(dest, basename(url))
9098
download_and_uncompress(url, file_path)
91-
99+
100+
92101
def download_caltech(dest="datasets/caltech-101"):
93102
print("Fetching Caltech-101 dataset")
94103
url = "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz"
95104
file_path = join(dest, basename(url))
96-
download_and_uncompress(url,file_path)
105+
download_and_uncompress(url, file_path)
97106

98107
print("Fetching Caltech-101 annotations")
99-
url="http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar"
108+
url = "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar"
100109
file_path = join(dest, basename(url))
101-
download_and_uncompress(url,file_path)
102-
103-
print('Renaming some annotation directories\n') ;
104-
rename(join(dest,'Annotations','Airplanes_Side_2'),join(dest,'Annotations','airplanes'))
105-
rename(join(dest,'Annotations','Faces_2'),join(dest,'Annotations','Faces'))
106-
rename(join(dest,'Annotations','Faces_3'),join(dest,'Annotations','Faces_easy'))
107-
rename(join(dest,'Annotations','Motorbikes_16'),join(dest,'Annotations','Motorbikes'))
108-
print('Done renaming\n') ;
109-
110-
print('Downloading image pair list \n') ;
111-
url='http://www.di.ens.fr/willow/research/weakalign/other_resources/test_pairs_caltech_with_category.csv'
110+
download_and_uncompress(url, file_path)
111+
112+
print('Renaming some annotation directories\n')
113+
rename(join(dest, 'Annotations', 'Airplanes_Side_2'), join(dest, 'Annotations', 'airplanes'))
114+
rename(join(dest, 'Annotations', 'Faces_2'), join(dest, 'Annotations', 'Faces'))
115+
rename(join(dest, 'Annotations', 'Faces_3'), join(dest, 'Annotations', 'Faces_easy'))
116+
rename(join(dest, 'Annotations', 'Motorbikes_16'), join(dest, 'Annotations', 'Motorbikes'))
117+
print('Done renaming\n')
118+
119+
print('Downloading image pair list \n')
120+
url = 'http://www.di.ens.fr/willow/research/weakalign/other_resources/test_pairs_caltech_with_category.csv'
112121
file_path = join(dest, basename(url))
113-
download_and_uncompress(url,file_path)
114-
122+
download_and_uncompress(url, file_path)
123+
124+
115125
def download_TSS(dest="datasets/tss"):
116-
print("Fetching TSS dataset ")
126+
print("Fetching TSS dataset ")
117127
url = "http://www.hci.iis.u-tokyo.ac.jp/datasets/data/JointCorrCoseg/TSS_CVPR2016.zip"
118128
file_path = join(dest, basename(url))
119-
download_and_uncompress(url,file_path)
120-
121-
print('Downloading image pair list \n') ;
122-
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_tss.csv"
123-
file_path = join(dest,basename(url))
124-
download_and_uncompress(url,file_path)
129+
download_and_uncompress(url, file_path)
125130

131+
print('Downloading image pair list \n')
132+
url = "http://www.di.ens.fr/willow/research/cnngeometric/other_resources/test_pairs_tss.csv"
133+
file_path = join(dest, basename(url))
134+
download_and_uncompress(url, file_path)

0 commit comments

Comments
 (0)