Skip to content

Commit 9cc5627

Browse files
authored
Cleanup (#11)
1 parent 09d01a5 commit 9cc5627

20 files changed

+39
-39
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,4 +126,4 @@ venv.bak/
126126
dmypy.json
127127

128128
# Pyre type checker
129-
.pyre/
129+
.pyre/

code_soup/ch5/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from code_soup.ch5.dummy import add_nums
1+
from code_soup.ch5.datasets import MnistDataset
2+
from code_soup.ch5.models import Discriminator, Generator

code_soup/ch5/datasets/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from code_soup.ch5.datasets.mnist import MnistDataset

code_soup/chapter_5/datasets/mnist.py code_soup/ch5/datasets/mnist.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from torch.utils.data import Dataset
21
import torchvision.datasets as datasets
2+
from torch.utils.data import Dataset
33

44

5-
class Mnist(Dataset):
5+
class MnistDataset(Dataset):
66
def __init__(self, transform=None):
77

88
self.train_data = datasets.MNIST(

code_soup/ch5/dummy.py

-6
This file was deleted.

code_soup/chapter_5/mnist_gan.py code_soup/ch5/mnist_gan.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from datasets import Mnist
2-
from models import Generator, Discriminator
1+
import argparse
2+
3+
import torch
4+
import torch.optim as optim
35
import torchvision.transforms as transforms
46
import torchvision.utils as vutils
5-
import torch.optim as optim
6-
import torch
7-
import argparse
7+
8+
from code_soup.ch5.datasets import MnistDataset
9+
from code_soup.ch5.models import Discriminator, Generator
810

911
parser = argparse.ArgumentParser(
1012
prog="mnist_gan.py", description="Train an MNIST GAN model"
@@ -52,7 +54,7 @@ def train_mnist_gan():
5254
transforms.Normalize((0.5,), (0.5,)),
5355
]
5456
)
55-
dataset = Mnist(transform=transform)
57+
dataset = MnistDataset(transform=transform)
5658
dataloader = torch.utils.data.DataLoader(
5759
dataset, batch_size=dataloader_batch_size, shuffle=True
5860
)
@@ -92,7 +94,7 @@ def train_mnist_gan():
9294
errD_real.backward()
9395

9496
D_x = output.mean().item()
95-
## Train with all-fake batch
97+
# Train with all-fake batch
9698
# Generate batch of latent vectors
9799
noise = torch.randn(batch_size, latent_dims, device=device)
98100
# Generate fake image batch with G
@@ -137,9 +139,10 @@ def train_mnist_gan():
137139
D_G_z2,
138140
)
139141
)
140-
#save model weights
142+
# save model weights
141143
torch.save(discriminator.state_dict(), "./discriminator.pth")
142144
torch.save(generator.state_dict(), "./generator.pth")
143145

146+
144147
if __name__ == "__main__":
145148
train_mnist_gan()

code_soup/ch5/models/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from code_soup.ch5.models.discriminator import Discriminator
2+
from code_soup.ch5.models.generator import Generator
File renamed without changes.

code_soup/chapter_5/datasets/__init__.py

-1
This file was deleted.

code_soup/chapter_5/models/__init__.py

-2
This file was deleted.

code_soup/readme.md

Whitespace-only changes.

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
numpy==1.21.1
22
Pillow==8.3.1
33
torch==1.9.0
4-
torchvision==0.10.0
4+
torchvision==0.10.0

tests/test_ch5/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from tests.test_ch5.test_dummy import test_add_nums
1+
from tests.test_ch5.test_datasets import TestMnistDataset
2+
from tests.test_ch5.test_models import TestDiscriminatorModel, TestGeneratorModel
+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .mnist_test import TestMnistDataset
1+
from tests.test_ch5.test_datasets.mnist_test import TestMnistDataset

tests/test_ch5/test_datasets/mnist_test.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,20 @@
22
from pathlib import Path
33

44
import torch
5+
import torchvision.datasets as datasets
56
from PIL import Image
67
from torch.utils.data import DataLoader, Dataset
78
from torchvision import transforms
8-
import torchvision.datasets as datasets
9-
from code_soup.chapter_5.datasets import *
9+
10+
from code_soup.ch5.datasets import MnistDataset
1011

1112

1213
class TestMnistDataset(unittest.TestCase):
1314
@classmethod
1415
def setUpClass(cls) -> None:
15-
mnist_dataset = Mnist(transform=transforms.Compose([transforms.ToTensor()]))
16+
mnist_dataset = MnistDataset(
17+
transform=transforms.Compose([transforms.ToTensor()])
18+
)
1619
mnist_dataloader = DataLoader(mnist_dataset, batch_size=64, shuffle=False)
1720
cls.samples = next(iter(mnist_dataloader))
1821

tests/test_ch5/test_dummy.py

-6
This file was deleted.
+2-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .discriminator_test import TestDiscriminatorModel
2-
from .generator_test import TestGeneratorModel
1+
from tests.test_ch5.test_models.discriminator_test import TestDiscriminatorModel
2+
from tests.test_ch5.test_models.generator_test import TestGeneratorModel

tests/test_ch5/test_models/discriminator_test.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import unittest
2+
13
import torch
24
import torch.nn as nn
3-
import unittest
4-
from code_soup.chapter_5.models import Discriminator
5+
6+
from code_soup.ch5.models import Discriminator
57

68

79
class TestDiscriminatorModel(unittest.TestCase):

tests/test_ch5/test_models/generator_test.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import unittest
2-
import torch.nn as nn
2+
33
import torch
4-
from code_soup.chapter_5.models import Generator
4+
import torch.nn as nn
5+
6+
from code_soup.ch5.models import Generator
57

68

79
class TestGeneratorModel(unittest.TestCase):

0 commit comments

Comments
 (0)