Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring #72

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions code_soup/common/perturbation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This abstract method to be under common-utils now



class Perturbation(ABC):
"""
Docstring for Abstract Class Perturbation
"""

@classmethod
@abstractmethod
def __init__(self):
pass
6 changes: 0 additions & 6 deletions code_soup/common/vision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +0,0 @@
from code_soup.common.vision.datasets.image_classification import (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will you be importing this directly in the files where required?

ImageClassificationDataset,
)
from code_soup.common.vision.datasets.vision_dataset import ( # THE ABSTRACT DATASET CLASS
VisionDataset,
)
36 changes: 0 additions & 36 deletions code_soup/common/vision/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,3 @@
from torchvision.models import (
alexnet,
densenet121,
densenet161,
densenet169,
densenet201,
googlenet,
inception_v3,
mnasnet0_5,
mnasnet0_75,
mnasnet1_0,
mnasnet1_3,
mobilenet_v2,
mobilenet_v3_large,
mobilenet_v3_small,
resnet18,
resnet34,
resnet50,
resnet101,
resnet152,
resnext50_32x4d,
resnext101_32x8d,
shufflenet_v2_x0_5,
shufflenet_v2_x1_0,
shufflenet_v2_x1_5,
shufflenet_v2_x2_0,
squeezenet1_0,
squeezenet1_1,
vgg11,
vgg13,
vgg16,
vgg19,
wide_resnet50_2,
wide_resnet101_2,
)

from code_soup.common.vision.models.allconvnet import AllConvNet
from code_soup.common.vision.models.nin import NIN
from code_soup.common.vision.models.simple_cnn_classifier import SimpleCnnClassifier
39 changes: 39 additions & 0 deletions code_soup/common/vision/perturbations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from abc import abstractmethod
from typing import Union

import numpy as np
import torch

from code_soup.common.perturbation import Perturbation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file according to the factoring will come under vision-> utils. Please do the needful

Also the common perturbation file should come under common->utils



class VisualPerturbation(Perturbation):
"""
Docstring for VisualPerturbations
"""

def __init__(
self,
original: Union[np.ndarray, torch.Tensor],
perturbed: Union[np.ndarray, torch.Tensor],
):
"""
Docstring
#Automatically cast to Tensor using the torch.from_numpy() in the __init__ using if
"""
raise NotImplementedError

def calculate_LPNorm(self, p: Union[int, str]):
raise NotImplementedError

def calculate_PSNR(self):
raise NotImplementedError

def calculate_RMSE(self):
raise NotImplementedError

def calculate_SAM(self):
raise NotImplementedError

def calculate_SRE(self):
raise NotImplementedError
60 changes: 60 additions & 0 deletions tests/test_common/test_vision/test_perturbations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import random
import unittest

import numpy as np
import torch
from torchvision.datasets.fakedata import FakeData
from torchvision.transforms import ToTensor

from code_soup.common.vision.perturbations import VisualPerturbation


class TestVisualPerturbation(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
df = FakeData(size=2, image_size=(3, 64, 64))
a, b = tuple(df)
a, b = ToTensor()(a[0]).unsqueeze_(0), ToTensor()(b[0]).unsqueeze_(0)
cls.obj_tensor = VisualPerturbation(original=a, perturbed=b)
cls.obj_numpy = VisualPerturbation(original=a.numpy(), perturbed=b.numpy())

def test_LPNorm(self):
self.assertAlmostEqual(
TestVisualPerturbation.obj_tensor.calculate_LPNorm(p=1), 4143.0249, places=3
)
self.assertAlmostEqual(
TestVisualPerturbation.obj_numpy.calculate_LPNorm(p="fro"),
45.6525,
places=3,
)

def test_PSNR(self):
self.assertAlmostEqual(
TestVisualPerturbation.obj_tensor.calculate_PSNR(),
33.773994480876496,
places=3,
)

def test_RMSE(self):
self.assertAlmostEqual(
TestVisualPerturbation.obj_tensor.calculate_RMSE(),
0.018409499898552895,
places=3,
)

def test_SAM(self):
self.assertAlmostEqual(
TestVisualPerturbation.obj_tensor.calculate_SAM(),
89.34839413786915,
places=3,
)

def test_SRE(self):
self.assertAlmostEqual(
TestVisualPerturbation.obj_tensor.calculate_SRE(),
41.36633261587073,
places=3,
)