Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup #11

Merged
merged 4 commits into from
Aug 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,4 @@ venv.bak/
dmypy.json

# Pyre type checker
.pyre/
.pyre/
3 changes: 2 additions & 1 deletion code_soup/ch5/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from code_soup.ch5.dummy import add_nums
from code_soup.ch5.datasets import MnistDataset
from code_soup.ch5.models import Discriminator, Generator
1 change: 1 addition & 0 deletions code_soup/ch5/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from code_soup.ch5.datasets.mnist import MnistDataset
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from torch.utils.data import Dataset
import torchvision.datasets as datasets
from torch.utils.data import Dataset


class Mnist(Dataset):
class MnistDataset(Dataset):
def __init__(self, transform=None):

self.train_data = datasets.MNIST(
Expand Down
6 changes: 0 additions & 6 deletions code_soup/ch5/dummy.py

This file was deleted.

19 changes: 11 additions & 8 deletions code_soup/chapter_5/mnist_gan.py → code_soup/ch5/mnist_gan.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from datasets import Mnist
from models import Generator, Discriminator
import argparse

import torch
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.optim as optim
import torch
import argparse

from code_soup.ch5.datasets import MnistDataset
from code_soup.ch5.models import Discriminator, Generator

parser = argparse.ArgumentParser(
prog="mnist_gan.py", description="Train an MNIST GAN model"
Expand Down Expand Up @@ -52,7 +54,7 @@ def train_mnist_gan():
transforms.Normalize((0.5,), (0.5,)),
]
)
dataset = Mnist(transform=transform)
dataset = MnistDataset(transform=transform)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=dataloader_batch_size, shuffle=True
)
Expand Down Expand Up @@ -92,7 +94,7 @@ def train_mnist_gan():
errD_real.backward()

D_x = output.mean().item()
## Train with all-fake batch
# Train with all-fake batch
# Generate batch of latent vectors
noise = torch.randn(batch_size, latent_dims, device=device)
# Generate fake image batch with G
Expand Down Expand Up @@ -137,9 +139,10 @@ def train_mnist_gan():
D_G_z2,
)
)
#save model weights
# save model weights
torch.save(discriminator.state_dict(), "./discriminator.pth")
torch.save(generator.state_dict(), "./generator.pth")


if __name__ == "__main__":
train_mnist_gan()
2 changes: 2 additions & 0 deletions code_soup/ch5/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from code_soup.ch5.models.discriminator import Discriminator
from code_soup.ch5.models.generator import Generator
1 change: 0 additions & 1 deletion code_soup/chapter_5/datasets/__init__.py

This file was deleted.

2 changes: 0 additions & 2 deletions code_soup/chapter_5/models/__init__.py

This file was deleted.

Empty file removed code_soup/readme.md
Empty file.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
numpy==1.21.1
Pillow==8.3.1
torch==1.9.0
torchvision==0.10.0
torchvision==0.10.0
3 changes: 2 additions & 1 deletion tests/test_ch5/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from tests.test_ch5.test_dummy import test_add_nums
from tests.test_ch5.test_datasets import TestMnistDataset
from tests.test_ch5.test_models import TestDiscriminatorModel, TestGeneratorModel
2 changes: 1 addition & 1 deletion tests/test_ch5/test_datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .mnist_test import TestMnistDataset
from tests.test_ch5.test_datasets.mnist_test import TestMnistDataset
9 changes: 6 additions & 3 deletions tests/test_ch5/test_datasets/mnist_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@
from pathlib import Path

import torch
import torchvision.datasets as datasets
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.datasets as datasets
from code_soup.chapter_5.datasets import *

from code_soup.ch5.datasets import MnistDataset


class TestMnistDataset(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
mnist_dataset = Mnist(transform=transforms.Compose([transforms.ToTensor()]))
mnist_dataset = MnistDataset(
transform=transforms.Compose([transforms.ToTensor()])
)
mnist_dataloader = DataLoader(mnist_dataset, batch_size=64, shuffle=False)
cls.samples = next(iter(mnist_dataloader))

Expand Down
6 changes: 0 additions & 6 deletions tests/test_ch5/test_dummy.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/test_ch5/test_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .discriminator_test import TestDiscriminatorModel
from .generator_test import TestGeneratorModel
from tests.test_ch5.test_models.discriminator_test import TestDiscriminatorModel
from tests.test_ch5.test_models.generator_test import TestGeneratorModel
6 changes: 4 additions & 2 deletions tests/test_ch5/test_models/discriminator_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest

import torch
import torch.nn as nn
import unittest
from code_soup.chapter_5.models import Discriminator

from code_soup.ch5.models import Discriminator


class TestDiscriminatorModel(unittest.TestCase):
Expand Down
6 changes: 4 additions & 2 deletions tests/test_ch5/test_models/generator_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest
import torch.nn as nn

import torch
from code_soup.chapter_5.models import Generator
import torch.nn as nn

from code_soup.ch5.models import Generator


class TestGeneratorModel(unittest.TestCase):
Expand Down