-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring #72
base: main
Are you sure you want to change the base?
Refactoring #72
Changes from all commits
1abb489
be37e77
48e548e
c38b8d1
cf2679a
e28454b
3aae645
95383d3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class Perturbation(ABC): | ||
""" | ||
Docstring for Abstract Class Perturbation | ||
""" | ||
|
||
@classmethod | ||
@abstractmethod | ||
def __init__(self): | ||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +0,0 @@ | ||
from code_soup.common.vision.datasets.image_classification import ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will you be importing this directly in the files where required? |
||
ImageClassificationDataset, | ||
) | ||
from code_soup.common.vision.datasets.vision_dataset import ( # THE ABSTRACT DATASET CLASS | ||
VisionDataset, | ||
) | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,3 @@ | ||
from torchvision.models import ( | ||
alexnet, | ||
densenet121, | ||
densenet161, | ||
densenet169, | ||
densenet201, | ||
googlenet, | ||
inception_v3, | ||
mnasnet0_5, | ||
mnasnet0_75, | ||
mnasnet1_0, | ||
mnasnet1_3, | ||
mobilenet_v2, | ||
mobilenet_v3_large, | ||
mobilenet_v3_small, | ||
resnet18, | ||
resnet34, | ||
resnet50, | ||
resnet101, | ||
resnet152, | ||
resnext50_32x4d, | ||
resnext101_32x8d, | ||
shufflenet_v2_x0_5, | ||
shufflenet_v2_x1_0, | ||
shufflenet_v2_x1_5, | ||
shufflenet_v2_x2_0, | ||
squeezenet1_0, | ||
squeezenet1_1, | ||
vgg11, | ||
vgg13, | ||
vgg16, | ||
vgg19, | ||
wide_resnet50_2, | ||
wide_resnet101_2, | ||
) | ||
|
||
from code_soup.common.vision.models.allconvnet import AllConvNet | ||
from code_soup.common.vision.models.nin import NIN | ||
from code_soup.common.vision.models.simple_cnn_classifier import SimpleCnnClassifier |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from abc import abstractmethod | ||
from typing import Union | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from code_soup.common.perturbation import Perturbation | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This file according to the factoring will come under vision-> utils. Please do the needful Also the common perturbation file should come under common->utils |
||
|
||
|
||
class VisualPerturbation(Perturbation): | ||
""" | ||
Docstring for VisualPerturbations | ||
""" | ||
|
||
def __init__( | ||
self, | ||
original: Union[np.ndarray, torch.Tensor], | ||
perturbed: Union[np.ndarray, torch.Tensor], | ||
): | ||
""" | ||
Docstring | ||
#Automatically cast to Tensor using the torch.from_numpy() in the __init__ using if | ||
""" | ||
raise NotImplementedError | ||
|
||
def calculate_LPNorm(self, p: Union[int, str]): | ||
raise NotImplementedError | ||
|
||
def calculate_PSNR(self): | ||
raise NotImplementedError | ||
|
||
def calculate_RMSE(self): | ||
raise NotImplementedError | ||
|
||
def calculate_SAM(self): | ||
raise NotImplementedError | ||
|
||
def calculate_SRE(self): | ||
raise NotImplementedError |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import random | ||
import unittest | ||
|
||
import numpy as np | ||
import torch | ||
from torchvision.datasets.fakedata import FakeData | ||
from torchvision.transforms import ToTensor | ||
|
||
from code_soup.common.vision.perturbations import VisualPerturbation | ||
|
||
|
||
class TestVisualPerturbation(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls) -> None: | ||
torch.manual_seed(42) | ||
np.random.seed(42) | ||
random.seed(42) | ||
df = FakeData(size=2, image_size=(3, 64, 64)) | ||
a, b = tuple(df) | ||
a, b = ToTensor()(a[0]).unsqueeze_(0), ToTensor()(b[0]).unsqueeze_(0) | ||
cls.obj_tensor = VisualPerturbation(original=a, perturbed=b) | ||
cls.obj_numpy = VisualPerturbation(original=a.numpy(), perturbed=b.numpy()) | ||
|
||
def test_LPNorm(self): | ||
self.assertAlmostEqual( | ||
TestVisualPerturbation.obj_tensor.calculate_LPNorm(p=1), 4143.0249, places=3 | ||
) | ||
self.assertAlmostEqual( | ||
TestVisualPerturbation.obj_numpy.calculate_LPNorm(p="fro"), | ||
45.6525, | ||
places=3, | ||
) | ||
|
||
def test_PSNR(self): | ||
self.assertAlmostEqual( | ||
TestVisualPerturbation.obj_tensor.calculate_PSNR(), | ||
33.773994480876496, | ||
places=3, | ||
) | ||
|
||
def test_RMSE(self): | ||
self.assertAlmostEqual( | ||
TestVisualPerturbation.obj_tensor.calculate_RMSE(), | ||
0.018409499898552895, | ||
places=3, | ||
) | ||
|
||
def test_SAM(self): | ||
self.assertAlmostEqual( | ||
TestVisualPerturbation.obj_tensor.calculate_SAM(), | ||
89.34839413786915, | ||
places=3, | ||
) | ||
|
||
def test_SRE(self): | ||
self.assertAlmostEqual( | ||
TestVisualPerturbation.obj_tensor.calculate_SRE(), | ||
41.36633261587073, | ||
places=3, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This abstract method to be under common-utils now