I’ve been working on training a PyTorch model but I’m running into some issues. Despite following the usual training steps, my model doesn’t seem to be learning properly. I suspect there might be a problem with how I’m updating the model parameters, but I can’t seem to figure out what’s wrong.
for epoch in range(200):
    epoch_loss = 0
    actual_values = []
    predicted_values = []
    model2.train()
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        h, c = None, None
        optimizer.zero_grad()
        y_hat = model2(batch_x, edge_index_sectors, edge_weights, h, c)
        loss = criterion(y_hat, batch_y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() / batch_x.shape[0]
        # Store the actual and predicted values for plotting
        actual_values.extend(batch_y.detach().cpu().numpy().flatten())
        predicted_values.extend(y_hat.detach().cpu().numpy().flatten())
    epoch_loss /= len(train_loader)
    print(f'Epoch {epoch+1}, Loss: {epoch_loss}')
    # Plot the graph of predicted vs actual values after each epoch
    plt.figure(figsize=(10,5))
    plt.plot(actual_values, label='Actual')
    plt.plot(predicted_values, label='Predicted')
    plt.legend()
    plt.show()
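A side note on the number printed each epoch: if `criterion` already averages over the batch (the default `reduction='mean'` for losses such as `torch.nn.MSELoss`; this is an assumption, since the post does not show how `criterion` is built), then dividing `loss.item()` by `batch_x.shape[0]` normalizes a second time. That would not by itself stop the model from learning, but it could make the reported loss look much smaller than it really is. A minimal sketch with made-up numbers:

```python
# Hypothetical per-sample squared errors for two batches of 4 samples each.
batches = [[4.0, 4.0, 4.0, 4.0], [2.0, 2.0, 2.0, 2.0]]


def mean(values):
    return sum(values) / len(values)


# What a mean-reduced criterion would return for each batch.
batch_losses = [mean(b) for b in batches]

# Reporting as in the loop above: divide each batch loss by the batch size
# again, then divide the running sum by the number of batches.
reported = sum(loss / len(b) for loss, b in zip(batch_losses, batches)) / len(batches)

# Plain average of the mean-reduced batch losses.
per_batch_mean = sum(batch_losses) / len(batches)

print(reported, per_batch_mean)
```

If the criterion does use mean reduction, accumulating `loss.item()` directly and dividing once by `len(train_loader)` would give the second figure.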
Here’s the relevant part of my code:
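import matplotlib.pyplot as plt
import torch
# Assumption: LRGCN is the recurrent layer from PyTorch Geometric Temporal.
from torch_geometric_temporal.nn.recurrent import LRGCN


class RecurrentGCN2(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN2, self).__init__()
        # LRGCN(in_channels, out_channels, num_relations, num_bases)
        self.recurrent = LRGCN(node_features, 256, 1, 1)
        # MLP head: 256 -> 64 -> 1
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(256, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 1))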
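One way to test the suspicion about parameter updates is to look at the gradient of every parameter after `loss.backward()` and to compare the parameter values before and after `optimizer.step()`. The sketch below is only an illustration: the small `torch.nn.Sequential` model, the random tensors, `Adam`, and `MSELoss` are all stand-ins (hypothetical, not taken from the post), to be swapped for `model2`, a real batch, and the actual optimizer and criterion.

```python
import torch

torch.manual_seed(0)

# Stand-in model, optimizer, loss and data (all hypothetical).
model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
x, y = torch.randn(32, 8), torch.randn(32, 1)

# Snapshot every parameter before the update.
before = {name: p.detach().clone() for name, p in model.named_parameters()}

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()

# Gradient norm per parameter; None means no gradient reached that parameter.
grad_norms = {
    name: (None if p.grad is None else p.grad.norm().item())
    for name, p in model.named_parameters()
}

optimizer.step()

# Largest absolute change of each parameter tensor caused by the step.
changes = {
    name: (p.detach() - before[name]).abs().max().item()
    for name, p in model.named_parameters()
}

for name in before:
    print(f"{name}: grad_norm={grad_norms[name]}, max_change={changes[name]}")
```

A parameter whose gradient is `None`, or whose change is exactly zero, is probably either not connected to the loss or not registered with the optimizer. It may also be worth confirming that `y_hat` and `batch_y` have the same shape before they reach `criterion`, since a mismatch can broadcast silently.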