###########################
# Paper Implementation: NormVAE on NeuroImaging data
# Authors: Sandesh Katakam
###########################
import argparse

import numpy as np
import pandas as pd
import torch
from torch import optim

from data_preprocessing import trainloader, testloader
from model import customLoss, Autoencoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

parser = argparse.ArgumentParser(description="Train NormVAE and export reconstructed samples")
parser.add_argument("--epochs", default=1500, type=int)          # number of training epochs
parser.add_argument("--gensamples", default=10, type=int)        # number of reconstructed samples to export
parser.add_argument("--output_format", default="csv", type=str)  # output format of the reconstructed samples
args = parser.parse_args()
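# Example invocation (flag names come from the parser above; the values shown
# are just the defaults):
#   python train_model.py --epochs 1500 --gensamples 10 --output_format csv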
epochs = args.epochs                         # number of epochs to train
D_in = trainloader.dataset.x.shape[1]        # input dimensionality (feature count)
H = 50                                       # first hidden layer size
H2 = 12                                      # second hidden layer size
log_interval = 50
val_losses = []
train_losses = []
test_losses = []
loss_mse = customLoss()                      # VAE loss: reconstruction MSE + KL divergence
model = Autoencoder(D_in, H, H2).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train(epoch):
    """Run one training epoch; log the average loss every 200 epochs."""
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(trainloader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_mse(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    if epoch % 200 == 0:
        print('====> Epoch: {} Average training loss: {:.4f}'.format(
            epoch, train_loss / len(trainloader.dataset)))
        train_losses.append(train_loss / len(trainloader.dataset))

def test(epoch):
    """Evaluate on the held-out test set; log the average loss every 200 epochs."""
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx, data in enumerate(testloader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            loss = loss_mse(recon_batch, data, mu, logvar)
            test_loss += loss.item()
    if epoch % 200 == 0:
        print('====> Epoch: {} Average test loss: {:.4f}'.format(
            epoch, test_loss / len(testloader.dataset)))
        test_losses.append(test_loss / len(testloader.dataset))

for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
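
###########################
# Post-training export. The block below is a minimal sketch, not part of the
# original script: --gensamples and --output_format are parsed above but
# otherwise unused, so this assumes they control how many reconstructions to
# save and in which format. Only model(data) is called, exactly as in test().
###########################

# Persist the loss curves collected during training (train_losses/test_losses
# are otherwise never read).
pd.DataFrame({"train_loss": train_losses, "test_loss": test_losses}).to_csv(
    "losses.csv", index=False)

# Reconstruct test samples through the trained VAE and export them.
model.eval()
reconstructions = []
with torch.no_grad():
    for data in testloader:
        data = data.to(device)
        recon_batch, mu, logvar = model(data)
        reconstructions.append(recon_batch.cpu().numpy())

samples = np.concatenate(reconstructions, axis=0)[:args.gensamples]
if args.output_format == "csv":
    pd.DataFrame(samples).to_csv("reconstructed_samples.csv", index=False)
else:
    np.save("reconstructed_samples.npy", samples)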