Skip to content

Commit 7f93f4c

Browse files
author
shruthi.gowda
committed
InBiaseD code
1 parent bf484f9 commit 7f93f4c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+7322
-0
lines changed

README.md

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
2+
3+
This is the official code for CoLLAs 2022 paper, "InBiaseD: Inductive Bias Distillation to Improve Generalization and Robustness through Shape-awareness" by Shruthi Gowda, Elahe Arani and Bahram Zonooz.
4+
5+
## Requirements
6+
- python==3.8.0
7+
- torch==1.10.0
8+
- torchvision==0.8.0
9+
10+
## Methodology
11+
```
12+
Network : Resnet18
13+
```
14+
#### InBiaseD framework
15+
16+
There are 2 networks:
17+
InBiaseD-network receiving RGB images and
18+
Shape-network receiving shape images. InBiaseD is used for inference.
19+
20+
![image info](./src/method.png)
21+
22+
## Datasets and Setup
23+
24+
![image info](./src/dataset.png)
25+
26+
The learning rate is set to $0.1$ (except for C-MNIST where it is $0.01$). SGD optimizer is used with a momentum of 0.9 and a weight decay of 1e-4. The same settings as for baselines are used for training InBiaseD. We apply random crop and random horizontal flip as the augmentations for all training.
27+
Resnet-18* refers to the CIFAR-version in which the first convolutional layer has 3x3 kernel and the maxpool operation is removed.
28+
29+
![image info](./src/setup.png)
30+
31+
## Running
32+
33+
#### Train Baseline
34+
35+
```
36+
python train_normal.py
37+
--exp_identifier train_tinyimagenet_baseline
38+
--model_architecture ResNet18
39+
--dataset tinyimagenet
40+
--data_path <path to tinyimagenet dataset>
41+
--lr 0.1
42+
--weight_decay 0.0005
43+
--batch_size 128
44+
--epochs 250
45+
--cuda
46+
--test_eval_freq 100
47+
--train_eval_freq 100
48+
--seeds 0 10 20
49+
--ft_prior std
50+
--scheduler cosine
51+
--output_dir /tinyimagenet_baseline
52+
```
53+
#### Train InBiaseD
54+
55+
```
56+
python train_inbiased.py
57+
--exp_identifier train_tinyimagenet_inbiased
58+
--model1_architecture ResNet18
59+
--model2_architecture ResNet18
60+
--dataset tinyimagenet
61+
--data_path <path to tinyimagenet dataset>
62+
--lr 0.1
63+
--weight_decay 0.0005
64+
--batch_size 128
65+
--epochs 250
66+
--cuda
67+
--test_eval_freq 100
68+
--train_eval_freq 100
69+
--seeds 0 10 20
70+
--ft_prior sobel
71+
--loss_type kl fitnet
72+
--loss_wt_kl1 1
73+
--loss_wt_kl2 1
74+
--loss_wt_at1 1
75+
--loss_wt_at2 5
76+
--scheduler cosine
77+
--output_dir /tinyimagenet_inbiased
78+
```
79+
80+
### Test
81+
#### For evaluation only one network is used - the InBiaseD network (the first network)
82+
```
83+
python test.py
84+
--dataset
85+
tinyimagenet
86+
--data_path
87+
<path to tinyimagenet dataset>
88+
--model_path
89+
/tinyimagenet_inbiased/final_model1.pth
90+
--cuda
91+
--output_dir
92+
/tinyimagenet_inbiased/results_inbiased
93+
```
94+
95+
#### For results of the Shape network (second network)
96+
```
97+
python test.py
98+
--dataset
99+
tinyimagenet
100+
--data_path
101+
<path to tinyimagenet dataset>
102+
--model_path
103+
/tinyimagenet_inbiased/final_model2.pth
104+
--cuda
105+
--output_dir
106+
/tinyimagenet_inbiased/results_shapenet
107+
```
108+
109+
#### For Ensemble of InBiaseD Networks:
110+
111+
For ensemble evaluation, both networks are used - the InBiaseD network (first network) and the Shape network (second network)
112+
```
113+
python ensemble.py
114+
--dataset
115+
tinyimagenet
116+
--data_path
117+
<path to tinyimagenet dataset>
118+
--model_path1
119+
/tinyimagenet_inbiased/final_model1.pth
120+
--model_path2
121+
/tinyimagenet_inbiased/final_model2.pth
122+
--cuda
123+
--output_dir
124+
/tinyimagenet_inbiased/results_ensemble
125+
```
126+
127+
## Cite Our Work
128+
129+
## License
130+
131+
This project is licensed under the terms of the MIT license.
132+

__init__.py

Whitespace-only changes.

data/__init__.py

Whitespace-only changes.

data/data_eval.py

+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import os
2+
from torchvision import datasets, transforms
3+
4+
# def get_test_dataset(dataset_choice, dataset_path='dataset', remove_norm=False):
5+
#
6+
# if dataset_choice == 'imagenet_r':
7+
# normalize = transforms.Normalize(mean=[0.4802, 0.4481, 0.3975],
8+
# std=[0.2302, 0.2265, 0.2262])
9+
#
10+
# dataset = ImageFilelist(
11+
# root=dataset_path,
12+
# flist=os.path.join(dataset_path, "annotations.txt"),
13+
# transform=transforms.Compose([
14+
# transforms.Resize(64),
15+
# transforms.CenterCrop(56),
16+
# transforms.ToTensor(),
17+
# normalize,
18+
# ])
19+
# )
20+
#
21+
# elif dataset_choice == 'imagenet_blurry':
22+
# normalize = transforms.Normalize(mean=[0.4802, 0.4481, 0.3975],
23+
# std=[0.2302, 0.2265, 0.2262])
24+
#
25+
# dataset = ImageFilelist(
26+
# root=dataset_path,
27+
# flist=os.path.join(dataset_path, "annotations.txt"),
28+
# transform=transforms.Compose([
29+
# transforms.Resize(64),
30+
# transforms.CenterCrop(56),
31+
# transforms.ToTensor(),
32+
# normalize,
33+
# ])
34+
# )
35+
#
36+
# elif dataset_choice == 'imagenet_a':
37+
# normalize = transforms.Normalize(mean=[0.4802, 0.4481, 0.3975],
38+
# std=[0.2302, 0.2265, 0.2262])
39+
#
40+
# dataset = ImageFilelist(
41+
# root=dataset_path,
42+
# flist=os.path.join(dataset_path, "annotations.txt"),
43+
# transform=transforms.Compose([
44+
# transforms.Resize(64),
45+
# transforms.CenterCrop(56),
46+
# transforms.ToTensor(),
47+
# normalize,
48+
# ]),
49+
# sep=','
50+
# )
51+
#
52+
# elif dataset_choice == 'celeba':
53+
# celeba_dataset = CelebA(dataset_path)
54+
# dataset = celeba_dataset.get_dataset(split='test')
55+
#
56+
# elif dataset_choice == 'stl_tinted':
57+
# STL_TEST_TRANSFORMS = transforms.Compose([
58+
# transforms.ToPILImage(),
59+
# # transforms.CenterCrop(96),
60+
# transforms.ToTensor(),
61+
# transforms.Normalize(
62+
# (0.4192, 0.4124, 0.3804),
63+
# (0.2714, 0.2679, 0.2771)
64+
# )
65+
# ])
66+
#
67+
# dataset = STLTint(dataset_path, 'test', STL_TEST_TRANSFORMS)
68+
#
69+
# elif dataset_choice == 'domain_net_real':
70+
#
71+
# transform = transforms.Compose([
72+
# transforms.Resize((224, 224)),
73+
# transforms.ToTensor(),
74+
# transforms.Normalize(
75+
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
76+
# ])
77+
#
78+
# dataset = ImageFilelist(
79+
# root=dataset_path,
80+
# flist=os.path.join(dataset_path, "real_test.txt"),
81+
# transform=transform
82+
# )
83+
#
84+
# elif dataset_choice == 'domain_net_clipart':
85+
#
86+
# transform = transforms.Compose([
87+
# transforms.Resize((224, 224)),
88+
# transforms.ToTensor(),
89+
# transforms.Normalize(
90+
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
91+
# ])
92+
#
93+
# dataset = ImageFilelist(
94+
# root=dataset_path,
95+
# flist=os.path.join(dataset_path, "clipart_test.txt"),
96+
# transform=transform
97+
# )
98+
#
99+
# elif dataset_choice == 'domain_net_infograph':
100+
#
101+
# transform = transforms.Compose([
102+
# transforms.Resize((224, 224)),
103+
# transforms.ToTensor(),
104+
# transforms.Normalize(
105+
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
106+
# ])
107+
#
108+
# dataset = ImageFilelist(
109+
# root=dataset_path,
110+
# flist=os.path.join(dataset_path, "infograph_test.txt"),
111+
# transform=transform
112+
# )
113+
#
114+
# elif dataset_choice == 'domain_net_painting':
115+
#
116+
# transform = transforms.Compose([
117+
# transforms.Resize((224, 224)),
118+
# transforms.ToTensor(),
119+
# transforms.Normalize(
120+
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
121+
# ])
122+
#
123+
# dataset = ImageFilelist(
124+
# root=dataset_path,
125+
# flist=os.path.join(dataset_path, "painting_test.txt"),
126+
# transform=transform
127+
# )
128+
#
129+
# elif dataset_choice == 'domain_net_sketch':
130+
#
131+
# transform = transforms.Compose([
132+
# transforms.Resize((224, 224)),
133+
# transforms.ToTensor(),
134+
# transforms.Normalize(
135+
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
136+
# ])
137+
#
138+
# dataset = ImageFilelist(
139+
# root=dataset_path,
140+
# flist=os.path.join(dataset_path, "sketch_test.txt"),
141+
# transform=transform
142+
# )
143+
#
144+
#
145+
# else:
146+
# raise ValueError(f'{dataset_choice} not supported')
147+
#
148+
# return dataset

0 commit comments

Comments
 (0)