@@ -1,10 +1,12 @@
-from datasets import Mnist
-from models import Generator, Discriminator
+import argparse
+
+import torch
+import torch.optim as optim
 import torchvision.transforms as transforms
 import torchvision.utils as vutils
-import torch.optim as optim
-import torch
-import argparse
+
+from code_soup.ch5.datasets import MnistDataset
+from code_soup.ch5.models import Discriminator, Generator
 
 parser = argparse.ArgumentParser(
     prog="mnist_gan.py", description="Train an MNIST GAN model"
@@ -52,7 +54,7 @@ def train_mnist_gan():
             transforms.Normalize((0.5,), (0.5,)),
         ]
     )
-    dataset = Mnist(transform=transform)
+    dataset = MnistDataset(transform=transform)
     dataloader = torch.utils.data.DataLoader(
         dataset, batch_size=dataloader_batch_size, shuffle=True
     )
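For reference, the data-loading step in this hunk corresponds to roughly the following standalone snippet. This is a minimal sketch, assuming that `code_soup.ch5.datasets.MnistDataset` only needs the `transform` argument (as in the diff) and using a placeholder batch size of 64 in place of `dataloader_batch_size`:

    import torch
    import torchvision.transforms as transforms

    from code_soup.ch5.datasets import MnistDataset

    # Same normalization as in the training script: map [0, 1] pixels to [-1, 1]
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )

    # Assumption: MnistDataset takes only a transform, as shown in the diff
    dataset = MnistDataset(transform=transform)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=64, shuffle=True  # 64 is a placeholder batch size
    )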
@@ -92,7 +94,7 @@ def train_mnist_gan():
         errD_real.backward()

         D_x = output.mean().item()
-        ## Train with all-fake batch
+        # Train with all-fake batch
         # Generate batch of latent vectors
         noise = torch.randn(batch_size, latent_dims, device=device)
         # Generate fake image batch with G
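This hunk sits inside the standard two-phase discriminator update: score an all-real batch, then score an all-fake batch generated from random noise. The sketch below reconstructs that pattern under the usual DCGAN-style assumptions (a `BCELoss` criterion, real/fake labels of 1 and 0, and a pre-built `optimizerD`); the function name `discriminator_step` and arguments such as `real_images` are illustrative, not taken from the file:

    import torch
    import torch.nn as nn

    def discriminator_step(discriminator, generator, real_images, optimizerD,
                           batch_size, latent_dims, device):
        """One discriminator update: all-real batch first, then an all-fake batch."""
        criterion = nn.BCELoss()
        discriminator.zero_grad()

        # (1) All-real batch: label 1, push D(x) toward 1
        label = torch.full((batch_size,), 1.0, device=device)
        output = discriminator(real_images).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()  # average score on real data

        # (2) All-fake batch: sample latent vectors, label 0, push D(G(z)) toward 0
        noise = torch.randn(batch_size, latent_dims, device=device)
        fake = generator(noise)
        label.fill_(0.0)
        output = discriminator(fake.detach()).view(-1)  # detach so G gets no gradients here
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()

        optimizerD.step()
        return errD_real + errD_fake, D_x, D_G_z1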
@@ -137,9 +139,10 @@ def train_mnist_gan():
                     D_G_z2,
                 )
             )
-    #save model weights
+    # save model weights
     torch.save(discriminator.state_dict(), "./discriminator.pth")
     torch.save(generator.state_dict(), "./generator.pth")
 
+
 if __name__ == "__main__":
     train_mnist_gan()