ATNs for ImageNet Dataset #100

Open · wants to merge 38 commits into main

Commits (38)
7c48069
Added seeding and checkpoints
soham-chitnis10 Aug 29, 2021
672949a
Added seeding and checkpoints
soham-chitnis10 Aug 30, 2021
8ae0d30
Added seeding and checkpoints
soham-chitnis10 Aug 30, 2021
3e941fd
Merge branch 'Adversarial-Deep-Learning:main' into main
soham-chitnis10 Aug 30, 2021
c814329
Updated seed function
soham-chitnis10 Aug 30, 2021
37c3bca
Updated seed function and test
soham-chitnis10 Aug 30, 2021
c2d298f
Merge branch 'main' into main
someshsingh22 Aug 30, 2021
7f12a3d
Merge branch 'Adversarial-Deep-Learning:main' into main
soham-chitnis10 Aug 30, 2021
deb9b17
Restructured according to PEP8
soham-chitnis10 Aug 30, 2021
e5ed281
Restructured according to REPO_STRUCTURE
soham-chitnis10 Aug 31, 2021
e8ed281
Added test for seeding and checkpoints
soham-chitnis10 Aug 31, 2021
de34749
delete test_model.tar
soham-chitnis10 Aug 31, 2021
651bf6b
Updated checkpoints
soham-chitnis10 Aug 31, 2021
c148e05
Added tests for checkpoints
soham-chitnis10 Aug 31, 2021
35bdd6e
Merge pull request #1 from soham-chitnis10/seeding-checkpoints
soham-chitnis10 Aug 31, 2021
e95f444
Updated seeding
soham-chitnis10 Aug 31, 2021
0036ad5
Merge pull request #2 from soham-chitnis10/seeding-checkpoints
soham-chitnis10 Aug 31, 2021
4eb2ccd
Merge branch 'Adversarial-Deep-Learning:main' into main
soham-chitnis10 Aug 31, 2021
dec8cca
Add test_checkpoints.py
soham-chitnis10 Aug 31, 2021
ffcc4be
Merge branch 'Adversarial-Deep-Learning:main' into seeding-checkpoints
soham-chitnis10 Aug 31, 2021
99d2cb2
Merge pull request #3 from soham-chitnis10/seeding-checkpoints
soham-chitnis10 Aug 31, 2021
9b67d23
Add test_checkpoints.py
soham-chitnis10 Aug 31, 2021
fcdc43e
Merge branch 'seeding-checkpoints' of https://github.com/soham-chitni…
soham-chitnis10 Aug 31, 2021
309e147
Merge pull request #4 from soham-chitnis10/seeding-checkpoints
soham-chitnis10 Aug 31, 2021
d40b8d1
minor changes
soham-chitnis10 Sep 1, 2021
4aa5da4
Updated test_checkpoints.py
soham-chitnis10 Sep 1, 2021
960c334
Updated test_checkpoints.py
soham-chitnis10 Sep 1, 2021
f60f0c2
Merge pull request #5 from soham-chitnis10/seeding-checkpoints
soham-chitnis10 Sep 1, 2021
09dcdde
Updated test_checkpoints.py
soham-chitnis10 Sep 1, 2021
65130af
Merge pull request #6 from soham-chitnis10/seeding-checkpoints
soham-chitnis10 Sep 1, 2021
fcdbf57
Merge branch 'Adversarial-Deep-Learning:main' into main
soham-chitnis10 Sep 2, 2021
5b8427a
Merge branch 'Adversarial-Deep-Learning:main' into main
soham-chitnis10 Sep 7, 2021
33b92ac
Merge branch 'Adversarial-Deep-Learning:main' into main
soham-chitnis10 Oct 10, 2021
246a3dd
Merge branch 'Adversarial-Deep-Learning:main' into main
soham-chitnis10 Oct 11, 2021
4733b3e
Merge branch 'Adversarial-Deep-Learning:main' into main
soham-chitnis10 Oct 11, 2021
6c8004a
Added imagenet modification
soham-chitnis10 Oct 27, 2021
38e52e5
Updated ResizeConvAAE
soham-chitnis10 Oct 27, 2021
495b5ad
Merge pull request #7 from soham-chitnis10/atn
soham-chitnis10 Oct 27, 2021
code_soup/ch5/algorithms/atn.py (170 additions, 0 deletions)

@@ -331,3 +331,173 @@ def forward(self, x):
logits = self.classifier_model(adv_out + x)
softmax_logits = F.softmax(logits, dim=1)
return adv_out + x, softmax_logits


class BilinearUpsample(nn.Module):
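    """Module wrapper around F.interpolate for fixed-factor bilinear upsampling,
    so resizing can sit inside an nn.ModuleList alongside other layers."""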
def __init__(self, scale_factor):
super(BilinearUpsample, self).__init__()
self.scale_factor = scale_factor

def forward(self, x):
return F.interpolate(
x, scale_factor=self.scale_factor, mode="bilinear", align_corners=True
)


class BaseDeconvAAE(SimpleAAE):
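    """Adversarial autoencoder ATN that decodes a pretrained backbone's
    [192, 35, 35] feature maps back to an image with transposed convolutions."""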
def __init__(
self,
classifier_model: torch.nn.Module,
pretrained_backbone: torch.nn.Module,
target_idx: int,
alpha: float = 1.5,
beta: float = 0.010,
backbone_output_shape: list = [192, 35, 35],
):

if backbone_output_shape != [192, 35, 35]:
raise ValueError("Backbone output shape must be [192, 35, 35].")

super(BaseDeconvAAE, self).__init__(classifier_model, target_idx, alpha, beta)

layers = [
pretrained_backbone,
nn.ZeroPad2d((1, 1, 1, 1)),
nn.ConvTranspose2d(192, 512, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.ZeroPad2d((3, 2, 3, 2)),
nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
nn.Tanh(),
]

self.atn = nn.ModuleList(layers)


class ResizeConvAAE(SimpleAAE):
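    """Adversarial autoencoder ATN built from convolutions and bilinear
    resizing, operating directly on the input image (no pretrained backbone)."""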
def __init__(
self,
classifier_model: torch.nn.Module,
target_idx: int,
alpha: float = 1.5,
beta: float = 0.010,
):

super(ResizeConvAAE, self).__init__(classifier_model, target_idx, alpha, beta)

layers = [
nn.Conv2d(3, 128, 5, padding=11),
nn.ReLU(),
BilinearUpsample(scale_factor=0.5),
nn.Conv2d(128, 256, 4, padding=11),
nn.ReLU(),
BilinearUpsample(scale_factor=0.5),
nn.Conv2d(256, 512, 3, padding=11),
nn.ReLU(),
BilinearUpsample(scale_factor=0.5),
nn.Conv2d(512, 512, 1, padding=11),
nn.ReLU(),
BilinearUpsample(scale_factor=2),
nn.Conv2d(512, 256, 3, padding=11),
nn.ReLU(),
BilinearUpsample(scale_factor=2),
nn.Conv2d(256, 128, 4, padding=11),
nn.ReLU(),
nn.ZeroPad2d((8, 8, 8, 8)),
nn.Conv2d(128, 3, 3, padding=11),
nn.Tanh(),
]

self.atn = nn.ModuleList(layers)


class ConvDeconvAAE(SimpleAAE):
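    """Adversarial autoencoder ATN with a strided-convolution encoder and a
    transposed-convolution decoder."""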
def __init__(
self,
classifier_model: torch.nn.Module,
target_idx: int,
alpha: float = 1.5,
beta: float = 0.010,
):

super(ConvDeconvAAE, self).__init__(classifier_model, target_idx, alpha, beta)

layers = [
nn.Conv2d(3, 256, 3, stride=2, padding=2),
nn.ReLU(),
nn.Conv2d(256, 512, 3, stride=2, padding=2),
nn.ReLU(),
nn.Conv2d(512, 768, 3, stride=2, padding=2),
nn.ReLU(),
nn.ConvTranspose2d(768, 512, kernel_size=4, stride=2, padding=2),
nn.ReLU(),
nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=2),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=2),
nn.ReLU(),
nn.ZeroPad2d((146, 145, 146, 145)),
nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
nn.Tanh(),
]

self.atn = nn.ModuleList(layers)


class BaseDeconvPATN(SimplePATN):
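    """Perturbation ATN counterpart of BaseDeconvAAE: the same
    backbone-plus-deconvolution stack, used to produce an additive perturbation."""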
def __init__(
self,
classifier_model: torch.nn.Module,
pretrained_backbone: torch.nn.Module,
target_idx: int,
alpha: float = 1.5,
beta: float = 0.010,
backbone_output_shape: list = [192, 35, 35],
):

if backbone_output_shape != [192, 35, 35]:
raise ValueError("Backbone output shape must be [192, 35, 35].")

super(BaseDeconvPATN, self).__init__(classifier_model, target_idx, alpha, beta)

layers = [
pretrained_backbone,
nn.ZeroPad2d((1, 1, 1, 1)),
nn.ConvTranspose2d(192, 512, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.ZeroPad2d((3, 2, 3, 2)),
nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),  # TODO: check whether Tanh is the right output activation
]

self.atn = nn.ModuleList(layers)


class ConvFCPATN(SimplePATN):
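    """Perturbation ATN with a convolutional encoder and fully connected
    layers emitting a flattened 3 x 299 x 299 output (268203 = 3 * 299 * 299)."""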
def __init__(
self,
classifier_model: torch.nn.Module,
target_idx: int,
alpha: float = 1.5,
beta: float = 0.010,
):

        super(ConvFCPATN, self).__init__(classifier_model, target_idx, alpha, beta)

layers = [
nn.Conv2d(3, 512, 3, stride=2, padding=22),
nn.Conv2d(512, 256, 3, stride=2, padding=22),
nn.Conv2d(256, 128, 3, stride=2, padding=22),
nn.Flatten(),
nn.Linear(184832, 512),
nn.Linear(512, 268203),
nn.Tanh(),
]

self.atn = nn.ModuleList(layers)
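
Below is a minimal usage sketch for the new BaseDeconvAAE (not part of this diff). It assumes that calling the model returns the adversarial image together with the classifier's softmax probabilities, as the existing SimpleAAE/SimplePATN forward passes suggest, and that the pretrained backbone has been truncated to emit [192, 35, 35] feature maps. The InceptionV3 truncation shown is one hypothetical way to obtain such a backbone (the child-module ordering is torchvision-version dependent), not something this PR provides.

import torch
import torchvision

# Classifier under attack; any ImageNet classifier taking 299 x 299 inputs
# should work here (assumption, not verified against the PR's tests).
classifier = torchvision.models.inception_v3(pretrained=True).eval()

# Hypothetical backbone: the InceptionV3 stem up to the second max-pool,
# which maps a 3 x 299 x 299 input to [192, 35, 35] feature maps.
backbone = torch.nn.Sequential(*list(classifier.children())[:7]).eval()

atn = BaseDeconvAAE(
    classifier_model=classifier,
    pretrained_backbone=backbone,
    target_idx=924,  # arbitrary ImageNet target class for the attack
)

x = torch.rand(2, 3, 299, 299)  # batch of ImageNet-sized images in [0, 1]
adv_image, softmax_probs = atn(x)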