diff --git a/code_soup/common/perturbation.py b/code_soup/common/perturbation.py new file mode 100644 index 0000000..b2e023d --- /dev/null +++ b/code_soup/common/perturbation.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod + + +class Perturbation(ABC): + """ + Docstring for Abstract Class Perturbation + """ + + @classmethod + @abstractmethod + def __init__(self): + pass diff --git a/code_soup/common/vision/datasets/__init__.py b/code_soup/common/vision/datasets/__init__.py index 13cab83..e69de29 100644 --- a/code_soup/common/vision/datasets/__init__.py +++ b/code_soup/common/vision/datasets/__init__.py @@ -1,6 +0,0 @@ -from code_soup.common.vision.datasets.image_classification import ( - ImageClassificationDataset, -) -from code_soup.common.vision.datasets.vision_dataset import ( # THE ABSTRACT DATASET CLASS - VisionDataset, -) diff --git a/code_soup/common/vision/models/__init__.py b/code_soup/common/vision/models/__init__.py index 6958db9..b52bba4 100644 --- a/code_soup/common/vision/models/__init__.py +++ b/code_soup/common/vision/models/__init__.py @@ -1,39 +1,3 @@ -from torchvision.models import ( - alexnet, - densenet121, - densenet161, - densenet169, - densenet201, - googlenet, - inception_v3, - mnasnet0_5, - mnasnet0_75, - mnasnet1_0, - mnasnet1_3, - mobilenet_v2, - mobilenet_v3_large, - mobilenet_v3_small, - resnet18, - resnet34, - resnet50, - resnet101, - resnet152, - resnext50_32x4d, - resnext101_32x8d, - shufflenet_v2_x0_5, - shufflenet_v2_x1_0, - shufflenet_v2_x1_5, - shufflenet_v2_x2_0, - squeezenet1_0, - squeezenet1_1, - vgg11, - vgg13, - vgg16, - vgg19, - wide_resnet50_2, - wide_resnet101_2, -) - from code_soup.common.vision.models.allconvnet import AllConvNet from code_soup.common.vision.models.nin import NIN from code_soup.common.vision.models.simple_cnn_classifier import SimpleCnnClassifier diff --git a/code_soup/common/vision/perturbations.py b/code_soup/common/vision/perturbations.py new file mode 100644 index 0000000..c68bc3c --- /dev/null +++ b/code_soup/common/vision/perturbations.py @@ -0,0 +1,39 @@ +from abc import abstractmethod +from typing import Union + +import numpy as np +import torch + +from code_soup.common.perturbation import Perturbation + + +class VisualPerturbation(Perturbation): + """ + Docstring for VisualPerturbations + """ + + def __init__( + self, + original: Union[np.ndarray, torch.Tensor], + perturbed: Union[np.ndarray, torch.Tensor], + ): + """ + Docstring + #Automatically cast to Tensor using the torch.from_numpy() in the __init__ using if + """ + raise NotImplementedError + + def calculate_LPNorm(self, p: Union[int, str]): + raise NotImplementedError + + def calculate_PSNR(self): + raise NotImplementedError + + def calculate_RMSE(self): + raise NotImplementedError + + def calculate_SAM(self): + raise NotImplementedError + + def calculate_SRE(self): + raise NotImplementedError diff --git a/tests/test_common/test_vision/test_perturbations.py b/tests/test_common/test_vision/test_perturbations.py new file mode 100644 index 0000000..eb71992 --- /dev/null +++ b/tests/test_common/test_vision/test_perturbations.py @@ -0,0 +1,60 @@ +import random +import unittest + +import numpy as np +import torch +from torchvision.datasets.fakedata import FakeData +from torchvision.transforms import ToTensor + +from code_soup.common.vision.perturbations import VisualPerturbation + + +class TestVisualPerturbation(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + torch.manual_seed(42) + np.random.seed(42) + random.seed(42) + df = FakeData(size=2, image_size=(3, 64, 64)) + a, b = tuple(df) + a, b = ToTensor()(a[0]).unsqueeze_(0), ToTensor()(b[0]).unsqueeze_(0) + cls.obj_tensor = VisualPerturbation(original=a, perturbed=b) + cls.obj_numpy = VisualPerturbation(original=a.numpy(), perturbed=b.numpy()) + + def test_LPNorm(self): + self.assertAlmostEqual( + TestVisualPerturbation.obj_tensor.calculate_LPNorm(p=1), 4143.0249, places=3 + ) + self.assertAlmostEqual( + TestVisualPerturbation.obj_numpy.calculate_LPNorm(p="fro"), + 45.6525, + places=3, + ) + + def test_PSNR(self): + self.assertAlmostEqual( + TestVisualPerturbation.obj_tensor.calculate_PSNR(), + 33.773994480876496, + places=3, + ) + + def test_RMSE(self): + self.assertAlmostEqual( + TestVisualPerturbation.obj_tensor.calculate_RMSE(), + 0.018409499898552895, + places=3, + ) + + def test_SAM(self): + self.assertAlmostEqual( + TestVisualPerturbation.obj_tensor.calculate_SAM(), + 89.34839413786915, + places=3, + ) + + def test_SRE(self): + self.assertAlmostEqual( + TestVisualPerturbation.obj_tensor.calculate_SRE(), + 41.36633261587073, + places=3, + )