# Classes appended to code_soup/ch5/algorithms/atn.py.
# NOTE(review): the unchanged tail of the preceding ``forward`` method (shown
# only as diff context, with its ``def`` line out of view) is not repeated here.


class BilinearUpsample(nn.Module):
    """Resize a tensor by a fixed scale factor using bilinear interpolation.

    Thin ``nn.Module`` wrapper around ``F.interpolate`` so that resizing can
    be placed in a list of layers. Despite the name, a ``scale_factor`` below
    1 shrinks the input.
    """

    def __init__(self, scale_factor):
        """Store the scale factor passed to ``F.interpolate`` on each call."""
        super(BilinearUpsample, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return ``x`` resized by ``self.scale_factor`` (bilinear mode)."""
        return F.interpolate(
            x, scale_factor=self.scale_factor, mode="bilinear", align_corners=True
        )


class BaseDeconvAAE(SimpleAAE):
    """``SimpleAAE`` variant: a pretrained backbone followed by deconvolutions.

    The layer stack is stored in ``self.atn`` as an ``nn.ModuleList``.
    """

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        pretrained_backbone: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
        backbone_output_shape: list = [192, 35, 35],
    ):
        """Build the backbone + transposed-convolution stack.

        Args:
            classifier_model: Forwarded unchanged to ``SimpleAAE.__init__``.
            pretrained_backbone: Module placed first in the layer stack.
            target_idx: Forwarded unchanged to ``SimpleAAE.__init__``.
            alpha: Forwarded unchanged to ``SimpleAAE.__init__``.
            beta: Forwarded unchanged to ``SimpleAAE.__init__``.
            backbone_output_shape: Must equal ``[192, 35, 35]``. The default
                list is only compared against, never mutated.

        Raises:
            ValueError: If ``backbone_output_shape`` is not ``[192, 35, 35]``.
        """
        # The first transposed convolution is hard-wired to 192 input
        # channels, so any other declared backbone shape is rejected up front.
        if backbone_output_shape != [192, 35, 35]:
            raise ValueError("Backbone output shape must be [192, 35, 35].")

        super(BaseDeconvAAE, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            pretrained_backbone,
            nn.ZeroPad2d((1, 1, 1, 1)),
            nn.ConvTranspose2d(192, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ZeroPad2d((3, 2, 3, 2)),
            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]

        self.atn = nn.ModuleList(layers)


class ResizeConvAAE(SimpleAAE):
    """``SimpleAAE`` variant built from convolutions and bilinear resizing.

    Three conv + resize(0.5) stages are followed by conv + resize(2) stages
    and a final 3-channel convolution with ``Tanh``. The layer stack is
    stored in ``self.atn`` as an ``nn.ModuleList``.
    """

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
    ):
        """Build the resize-convolution stack.

        All four arguments are forwarded unchanged to ``SimpleAAE.__init__``.
        """
        super(ResizeConvAAE, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            nn.Conv2d(3, 128, 5, padding=11),
            nn.ReLU(),
            BilinearUpsample(scale_factor=0.5),
            nn.Conv2d(128, 256, 4, padding=11),
            nn.ReLU(),
            BilinearUpsample(scale_factor=0.5),
            nn.Conv2d(256, 512, 3, padding=11),
            nn.ReLU(),
            BilinearUpsample(scale_factor=0.5),
            nn.Conv2d(512, 512, 1, padding=11),
            nn.ReLU(),
            BilinearUpsample(scale_factor=2),
            nn.Conv2d(512, 256, 3, padding=11),
            nn.ReLU(),
            BilinearUpsample(scale_factor=2),
            nn.Conv2d(256, 128, 4, padding=11),
            nn.ReLU(),
            nn.ZeroPad2d((8, 8, 8, 8)),
            nn.Conv2d(128, 3, 3, padding=11),
            nn.Tanh(),
        ]

        self.atn = nn.ModuleList(layers)


class ConvDeconvAAE(SimpleAAE):
    """``SimpleAAE`` variant: strided convolutions then transposed convolutions.

    The layer stack is stored in ``self.atn`` as an ``nn.ModuleList``.
    """

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
    ):
        """Build the conv/deconv stack.

        All four arguments are forwarded unchanged to ``SimpleAAE.__init__``.
        """
        super(ConvDeconvAAE, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            nn.Conv2d(3, 256, 3, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv2d(256, 512, 3, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv2d(512, 768, 3, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(768, 512, kernel_size=4, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=2),
            nn.ReLU(),
            nn.ZeroPad2d((146, 145, 146, 145)),
            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]

        self.atn = nn.ModuleList(layers)


class BaseDeconvPATN(SimplePATN):
    """``SimplePATN`` variant: a pretrained backbone followed by deconvolutions.

    Uses the same layer stack as ``BaseDeconvAAE`` but derives from
    ``SimplePATN``. The stack is stored in ``self.atn`` as an
    ``nn.ModuleList``.
    """

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        pretrained_backbone: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
        backbone_output_shape: list = [192, 35, 35],
    ):
        """Build the backbone + transposed-convolution stack.

        Args:
            classifier_model: Forwarded unchanged to ``SimplePATN.__init__``.
            pretrained_backbone: Module placed first in the layer stack.
            target_idx: Forwarded unchanged to ``SimplePATN.__init__``.
            alpha: Forwarded unchanged to ``SimplePATN.__init__``.
            beta: Forwarded unchanged to ``SimplePATN.__init__``.
            backbone_output_shape: Must equal ``[192, 35, 35]``. The default
                list is only compared against, never mutated.

        Raises:
            ValueError: If ``backbone_output_shape`` is not ``[192, 35, 35]``.
        """
        # The first transposed convolution is hard-wired to 192 input
        # channels, so any other declared backbone shape is rejected up front.
        if backbone_output_shape != [192, 35, 35]:
            raise ValueError("Backbone output shape must be [192, 35, 35].")

        super(BaseDeconvPATN, self).__init__(classifier_model, target_idx, alpha, beta)

        layers = [
            pretrained_backbone,
            nn.ZeroPad2d((1, 1, 1, 1)),
            nn.ConvTranspose2d(192, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ZeroPad2d((3, 2, 3, 2)),
            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),  # TODO: Check if right activation
        ]

        self.atn = nn.ModuleList(layers)


class ConvFCPATN(SimplePATN):
    """``SimplePATN`` variant: strided convolutions then fully connected layers.

    The layer stack is stored in ``self.atn`` as an ``nn.ModuleList``.
    """

    def __init__(
        self,
        classifier_model: torch.nn.Module,
        target_idx: int,
        alpha: float = 1.5,
        beta: float = 0.010,
    ):
        """Build the conv + fully connected stack.

        All four arguments are forwarded unchanged to ``SimplePATN.__init__``.
        """
        # Two-argument super() must name this class: naming an unrelated
        # class raises TypeError because ``self`` is not an instance of it.
        super(ConvFCPATN, self).__init__(classifier_model, target_idx, alpha, beta)

        # NOTE(review): 184832 == 128 * 38 * 38 and 268203 == 3 * 299 * 299.
        # Confirm the flattened conv output really has 184832 features for the
        # input size callers use with padding=22, and that the parent class
        # expects a flat 268203-element output.
        layers = [
            nn.Conv2d(3, 512, 3, stride=2, padding=22),
            nn.Conv2d(512, 256, 3, stride=2, padding=22),
            nn.Conv2d(256, 128, 3, stride=2, padding=22),
            nn.Flatten(),
            nn.Linear(184832, 512),
            nn.Linear(512, 268203),
            nn.Tanh(),
        ]

        self.atn = nn.ModuleList(layers)