-
Notifications
You must be signed in to change notification settings - Fork 0
/
tuning.py
91 lines (78 loc) · 2.65 KB
/
tuning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from ax.plot.contour import plot_contour
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render
from torchvision import transforms
from torch.utils.data import DataLoader
from networks.zeoliteAdaDNN import AdaThreeLayerNet
from networks.zeoliteCNN import TinyConvNet
from train import train
from utils.data import ZeoStructDataset, PeriodicPadding, ToTensor
# Perform hyperparameter tuning on PyTorch models using FACEBOOK's Ax
# https://ax.dev/ <-- for further details on Ax
def evaluate(net, data_loader):
    """
    Compute the mean squared error of a model on the provided dataset.

    Args:
        net: trained model (a callable torch module mapping inputs to predictions)
        data_loader: iterable yielding (inputs, target) batches

    Returns:
        dict: {'MSE': float} — mean squared error over all target elements,
        as a plain Python float (Ax expects a numeric metric value, not a tensor)

    Raises:
        ValueError: if the data loader yields no elements.
    """
    net.eval()
    squared_error = 0.0
    n_elements = 0
    with torch.no_grad():
        for inputs, target in data_loader:
            target = target.view(-1, 1)  # reshape to match the network's output shape
            outputs = net(inputs)
            # .sum().item() converts to a Python float immediately, so we never
            # accumulate tensors (keeps memory flat and the result a plain number)
            squared_error += ((target - outputs) ** 2).sum().item()
            n_elements += target.numel()
    if n_elements == 0:
        raise ValueError('data_loader yielded no samples; cannot compute MSE')
    # Divide by element count (true MSE), which equals the original per-batch
    # average for batch_size=1 with scalar targets but also handles larger batches.
    return {'MSE': squared_error / n_elements}
# Early Ax experiments tried to learn directly from the energy grid files.
def train_evaluate(parameterization):
    """Train a TinyConvNet with Ax-supplied hyperparameters; report validation MSE.

    Args:
        parameterization: dict of hyperparameter values chosen by Ax for this trial.

    Returns:
        dict: the {'MSE': ...} result of evaluate() on the validation set.
    """
    # Fixed seed so trials differ only in their hyperparameters.
    torch.manual_seed(12345)
    model = TinyConvNet(output_size=(24, 24, 24))
    train_set = ZeoStructDataset(
        csv_file='/Data/Zeolites/batch_V2/train/train_kH.csv',
        root_dir='/Data/Zeolites/batch_V2/train/',
        transform=transforms.Compose([PeriodicPadding(), ToTensor()]),
    )
    val_set = ZeoStructDataset(
        csv_file='/Data/Zeolites/batch_V2/val/val_kH.csv',
        root_dir='/Data/Zeolites/batch_V2/val/',
        transform=transforms.Compose([PeriodicPadding(), ToTensor()]),
    )
    val_loader = DataLoader(val_set, batch_size=1, shuffle=True)
    # Suppress all post-training side effects (plots, checkpoints) during tuning.
    post_train = dict(plot_loss=False, plot_weights=False, save_net=False)
    model = train(model, train_data=train_set, batch_size=1, parameters=parameterization, post_train=post_train)
    return evaluate(net=model, data_loader=val_loader)
# Ax search space: each dict defines one hyperparameter that Ax passes to
# train_evaluate via the `parameterization` argument.
parameters = [
    {
        'name': 'lr',
        'type': 'range',
        # Learning rate spans several orders of magnitude, so search on a log scale.
        'bounds': [1e-6, 0.4],
        'log_scale': True,
    },
    {
        'name': 'weight_decay',
        'type': 'range',
        # Regularization strength, searched on a linear scale.
        'bounds': [0.0, 1.0],
    }
]
# Run Ax's managed optimization loop over the search space defined above.
# NOTE: objective_name must match the metric key returned by train_evaluate
# ('MSE'); Ax looks the objective up by this name in the evaluation dict,
# so the previous value 'minimization' would never be found.
best_parameters, values, experiment, model = optimize(
    parameters=parameters,
    experiment_name='TinyConvNet tuning',
    objective_name='MSE',
    evaluation_function=train_evaluate,
    minimize=True,  # lower MSE is better
)
print(best_parameters)
# `values` holds the model-predicted (means, covariances) at the best point.
means, covariances = values
print(means, covariances)
# Contour plot of the surrogate model's predicted MSE over the
# (lr, weight_decay) plane.
render(plot_contour(model=model, param_x='lr', param_y='weight_decay', metric_name='MSE'))