-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
69 lines (44 loc) · 1.51 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader
import pandas as pd
from VanillaTransformerModel import Transformer
# Prefer GPU when available; all tensors/model are moved to this device below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load data
# Expects CSVs under data/; test_df is loaded but not used in this file.
train_df = pd.read_csv('data/train.csv')
test_df = pd.read_csv('data/test.csv')
# environment variable with wandb api key must be set
wandb.login()
run = wandb.init(project="VanillaTransformer")
# create dataset
# NOTE(review): VanillaTransformerDataset is never imported in this file
# (L79 imports only Transformer from VanillaTransformerModel) — this raises
# NameError as written; confirm which module provides the dataset class.
train_dataset = VanillaTransformerDataset(train_df)
# create dataloader
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# create model
# Positional args presumably map to (num_layers?, heads?, d_model?, ...) —
# TODO confirm against Transformer.__init__; signature not visible here.
model = Transformer(4, 6, 512, 8, 1024, 0.1,64, "gelu").to(device)
# create optimizer
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# create loss function
# MSE loss => this is a regression task.
criterion = nn.MSELoss()
# Log gradients and parameters to wandb every 100 batches.
wandb.watch(model, criterion, log="all", log_freq=100)
# training loop
# Best (lowest) per-batch loss seen so far; used for checkpointing below.
min_loss = 1e12
# Training loop: 10 epochs of MSE minimization over train_dataloader.
# Logs per-batch loss to wandb and checkpoints model.pt whenever the
# per-batch loss improves on the best seen so far (min_loss).
for epoch in range(10):
    # Ensure dropout (model uses p=0.1) is active during training.
    model.train()
    for batch in tqdm.tqdm(train_dataloader):
        optimizer.zero_grad()
        # 'inputs'/'targets' avoid shadowing the builtin 'input'.
        inputs = batch['input'].to(device)
        targets = batch['target'].to(device)
        output = model(inputs)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()
        # .item() once per batch; detaches scalar from the graph.
        batch_loss = loss.item()
        wandb.log({"loss": batch_loss})
        # Checkpoint on best per-batch loss. NOTE(review): a single batch is
        # a noisy signal — an epoch-average would be more robust, but the
        # per-batch criterion is kept to preserve existing behavior.
        if batch_loss < min_loss:
            min_loss = batch_loss
            torch.save(model.state_dict(), "model.pt")
    # Per-epoch summary uses the LAST batch's loss, matching the original.
    print(f"Epoch: {epoch}, Loss: {loss.item()}")
    wandb.log({"epoch": epoch, "loss": loss.item()})