-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
62 lines (54 loc) · 1.95 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import numpy as np
from datetime import datetime
import torch
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
# Training hyperparameters.
batch_size = 128
num_epochs = 100
seed = 1
out_dir = './vae_2'  # output directory for saved models / generated images
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Save the model and the loss history.
class VAE(nn.Module):
    """Fully-connected variational autoencoder for flattened 28x28 images.

    The encoder maps a (batch, 784) input to the mean and variance of a
    diagonal Gaussian posterior q(z|x); the decoder maps a latent sample
    back to Bernoulli pixel probabilities in (0, 1).
    """

    def __init__(self, z_dim):
        """Build encoder/decoder layers.

        Args:
            z_dim: dimensionality of the latent space.
        """
        super(VAE, self).__init__()
        # Encoder: 784 -> 200 -> 200 -> (mean, var), each of size z_dim.
        self.dense_enc1 = nn.Linear(28*28, 200)
        self.dense_enc2 = nn.Linear(200, 200)
        self.dense_encmean = nn.Linear(200, z_dim)
        self.dense_encvar = nn.Linear(200, z_dim)
        # Decoder: z_dim -> 200 -> 200 -> 784.
        self.dense_dec1 = nn.Linear(z_dim, 200)
        self.dense_dec2 = nn.Linear(200, 200)
        self.dense_dec3 = nn.Linear(200, 28*28)

    def _encoder(self, x):
        """Return (mean, var) of q(z|x); softplus keeps the variance positive."""
        h = F.relu(self.dense_enc1(x))
        h = F.relu(self.dense_enc2(h))
        mean = self.dense_encmean(h)
        var = F.softplus(self.dense_encvar(h))
        return mean, var

    def _sample_z(self, mean, var):
        """Reparameterization trick: z = mean + sqrt(var) * eps, eps ~ N(0, I).

        torch.randn_like allocates the noise on the same device/dtype as
        `mean`, so no module-level `device` global is needed (the original
        built the tensor on CPU and then moved it).
        """
        epsilon = torch.randn_like(mean)
        return mean + torch.sqrt(var) * epsilon

    def _decoder(self, z):
        """Map a latent sample to Bernoulli pixel probabilities in (0, 1)."""
        h = F.relu(self.dense_dec1(z))
        h = F.relu(self.dense_dec2(h))
        return torch.sigmoid(self.dense_dec3(h))

    def forward(self, x):
        """Encode x, sample a latent, decode; return (reconstruction, z)."""
        mean, var = self._encoder(x)
        z = self._sample_z(mean, var)
        return self._decoder(z), z

    def loss(self, x):
        """Negative ELBO for a batch, averaged per sample.

        Fixes vs. the original:
        - The per-sample KL and reconstruction terms are summed over the
          feature dimension (dim=1) and then averaged over the batch; the
          original summed over *all* elements, so `torch.mean` was a no-op
          on a scalar and the loss scaled with batch size.
        - The decoder output is clamped away from 0/1 before the logs, so
          a saturated sigmoid cannot produce -inf/NaN.
        """
        mean, var = self._encoder(x)
        # KL(q(z|x) || N(0, I)) per sample, averaged over the batch.
        KL = -0.5 * torch.mean(
            torch.sum(1 + torch.log(var) - mean**2 - var, dim=1))
        z = self._sample_z(mean, var)
        y = self._decoder(z)
        # Clamp to keep log() finite when the sigmoid saturates.
        eps = 1e-8
        y = torch.clamp(y, eps, 1 - eps)
        # Bernoulli log-likelihood per sample, averaged over the batch.
        reconstruction = torch.mean(
            torch.sum(x * torch.log(y) + (1 - x) * torch.log(1 - y), dim=1))
        # ELBO = reconstruction - KL; minimize its negation.
        return KL - reconstruction