Merge pull request #122 from elmahyai/master
Make attention trainable in A3TGCN and make it support batches
benedekrozemberczki authored Dec 26, 2021
2 parents 678752c + 40dfeb0 commit 88cc03c
Showing 6 changed files with 412 additions and 3 deletions.
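
The core of the change, in brief: A3TGCN's per-period attention vector was previously created with torch.empty and never registered with the module, so it could not be trained; it is now a torch.nn.Parameter that is softmax-normalized and used to weight the hidden state of each time period. A minimal standalone sketch of the idea (illustrative names only; the real code is in attentiontemporalgcn.py below):

import torch

periods = 12
# Registering the weights as a Parameter makes them trainable.
attention = torch.nn.Parameter(torch.empty(periods))
torch.nn.init.uniform_(attention)

# At forward time the weights are softmax-normalized and blend the
# hidden state produced by the base TGCN cell for each period.
probs = torch.nn.functional.softmax(attention, dim=0)
hidden_states = [torch.randn(207, 32) for _ in range(periods)]  # stand-ins for TGCN outputs
h_accum = sum(p * h for p, h in zip(probs, hidden_states))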
176 changes: 176 additions & 0 deletions examples/recurrent/a3tgcn2_example.py
@@ -0,0 +1,176 @@
# I published a working notebook of this example at https://www.kaggle.com/elmahy/a3t-gcn-for-traffic-forecasting

# The contribution makes training possible because it supports batches of data


import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric_temporal.nn.recurrent import A3TGCN2
# GPU support
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
shuffle=True
batch_size = 32



# Dataset: traffic forecasting based on Los Angeles Metropolitan traffic
# 207 loop detectors on highways
# March 2012 - June 2012
# From the paper: Diffusion Convolutional Recurrent Neural Network


from torch_geometric_temporal.dataset import METRLADatasetLoader
loader = METRLADatasetLoader()
dataset = loader.get_dataset(num_timesteps_in=12, num_timesteps_out=12)
print("Dataset type: ", dataset)
print("Number of samples / sequences: ", len(set(dataset)))


# Visualize traffic over time
sensor_number = 1
hours = 24
sensor_labels = [bucket.y[sensor_number][0].item() for bucket in list(dataset)[:hours]]
plt.plot(sensor_labels)

# Train test split

from torch_geometric_temporal.signal import temporal_signal_split
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)

print("Number of train buckets: ", len(set(train_dataset)))
print("Number of test buckets: ", len(set(test_dataset)))


# Creating Dataloaders

train_input = np.array(train_dataset.features) # (27399, 207, 2, 12)
train_target = np.array(train_dataset.targets) # (27399, 207, 12)
train_x_tensor = torch.from_numpy(train_input).type(torch.FloatTensor).to(DEVICE) # (B, N, F, T)
train_target_tensor = torch.from_numpy(train_target).type(torch.FloatTensor).to(DEVICE) # (B, N, T)
train_dataset_new = torch.utils.data.TensorDataset(train_x_tensor, train_target_tensor)
train_loader = torch.utils.data.DataLoader(train_dataset_new, batch_size=batch_size, shuffle=shuffle, drop_last=True)


test_input = np.array(test_dataset.features) # (, 207, 2, 12)
test_target = np.array(test_dataset.targets) # (, 207, 12)
test_x_tensor = torch.from_numpy(test_input).type(torch.FloatTensor).to(DEVICE) # (B, N, F, T)
test_target_tensor = torch.from_numpy(test_target).type(torch.FloatTensor).to(DEVICE) # (B, N, T)
test_dataset_new = torch.utils.data.TensorDataset(test_x_tensor, test_target_tensor)
test_loader = torch.utils.data.DataLoader(test_dataset_new, batch_size=batch_size, shuffle=shuffle, drop_last=True)



# Making the model
class TemporalGNN(torch.nn.Module):
    def __init__(self, node_features, periods, batch_size):
        super(TemporalGNN, self).__init__()
        # Attention Temporal Graph Convolutional Cell
        self.tgnn = A3TGCN2(in_channels=node_features, out_channels=32, periods=periods, batch_size=batch_size)  # node_features=2, periods=12
        # Equals single-shot prediction
        self.linear = torch.nn.Linear(32, periods)

    def forward(self, x, edge_index):
        """
        x = Node features for T time steps
        edge_index = Graph edge indices
        """
        h = self.tgnn(x, edge_index)  # x [b, 207, 2, 12] returns h [b, 207, 12]
        h = F.relu(h)
        h = self.linear(h)
        return h

# Display the model architecture
print(TemporalGNN(node_features=2, periods=12, batch_size=2))
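
# As a quick sanity check, the model maps inputs of shape (B, N, F, T) to
# predictions of shape (B, N, T). A hedged sketch, using a hypothetical
# two-edge toy graph purely to confirm shapes:

tmp_model = TemporalGNN(node_features=2, periods=12, batch_size=2).to(DEVICE)
x = torch.rand(2, 207, 2, 12).to(DEVICE)                        # (B, N, F, T)
toy_edge_index = torch.LongTensor([[0, 1], [1, 0]]).to(DEVICE)  # hypothetical toy graph
assert tmp_model(x, toy_edge_index).shape == (2, 207, 12)       # (B, N, T)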



# Create model and optimizers
model = TemporalGNN(node_features=2, periods=12, batch_size=batch_size).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()


print('Net\'s state_dict:')
total_param = 0
for param_tensor in model.state_dict():
    print(param_tensor, '\t', model.state_dict()[param_tensor].size())
    total_param += np.prod(model.state_dict()[param_tensor].size())
print('Net\'s total params:', total_param)
#--------------------------------------------------
print('Optimizer\'s state_dict:')  # Note that the attention weights now appear here as a trainable parameter
for var_name in optimizer.state_dict():
    print(var_name, '\t', optimizer.state_dict()[var_name])



# Loading the graph once because it's a static graph

for snapshot in train_dataset:
    static_edge_index = snapshot.edge_index.to(DEVICE)
    break



# Training the model
model.train()

for epoch in range(30):
    step = 0
    loss_list = []
    for encoder_inputs, labels in train_loader:
        y_hat = model(encoder_inputs, static_edge_index)  # Get model predictions
        loss = loss_fn(y_hat, labels)  # Mean squared error (take the square root for RMSE)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        step += 1
        loss_list.append(loss.item())
        if step % 100 == 0:
            print(sum(loss_list) / len(loss_list))
    print("Epoch {} train MSE: {:.4f}".format(epoch, sum(loss_list) / len(loss_list)))


## Evaluation

# - Let's get some sample predictions for a specific horizon (e.g. 288/12 = 24: at
#   5-minute resolution, 24 one-hour samples cover a full day)
# - The model always sees one hour of data and has to predict the next hour

model.eval()
step = 0
# Store for analysis
total_loss = []
for encoder_inputs, labels in test_loader:
    # Get model predictions
    y_hat = model(encoder_inputs, static_edge_index)
    # Mean squared error
    loss = loss_fn(y_hat, labels)
    total_loss.append(loss.item())
    # Store for analysis below
    # test_labels.append(labels)
    # predictions.append(y_hat)


print("Test MSE: {:.4f}".format(sum(total_loss)/len(total_loss)))


### Visualization

# - The further ahead the predicted point in time is, the worse the prediction gets
# - Predictions shape: [num_data_points, num_sensors, num_timesteps]


sensor = 123
timestep = 11
preds = np.asarray([pred[sensor][timestep].detach().cpu().numpy() for pred in y_hat])
labs = np.asarray([label[sensor][timestep].cpu().numpy() for label in labels])
print("Data points:,", preds.shape)

plt.figure(figsize=(20,5))
sns.lineplot(data=preds, label="pred")
sns.lineplot(data=labs, label="true")
2 changes: 1 addition & 1 deletion setup.py
@@ -7,7 +7,7 @@
"torch_scatter",
"torch_cluster",
"torch_spline_conv",
"torch_geometric",
"torch_geometric==1.7.0",
"numpy",
"scipy",
"tqdm",
51 changes: 51 additions & 0 deletions test/recurrent_test.py
@@ -8,6 +8,7 @@
    EvolveGCNO,
    TGCN,
    A3TGCN,
    A3TGCN2,
    MPNNLSTM,
)

@@ -22,6 +23,8 @@ def create_mock_data(number_of_nodes, edge_per_node, in_channels):
    return X, edge_index




def create_mock_attention_data(number_of_nodes, edge_per_node, in_channels, periods):
    """
    Creating a mock stacked feature matrix and edge index.
@@ -34,6 +37,19 @@ def create_mock_attention_data(number_of_nodes, edge_per_node, in_channels, periods):
    return X, edge_index



def create_mock_attention_batch_data(number_of_nodes, edge_per_node, in_channels, periods, batch_size):
    """
    Creating a mock stacked feature matrix in batches and edge index.
    """
    graph = nx.watts_strogatz_graph(number_of_nodes, edge_per_node, 0.5)
    edge_index = torch.LongTensor(np.array([edge for edge in graph.edges()]).T)
    X = torch.FloatTensor(
        np.random.uniform(-1, 1, (batch_size, number_of_nodes, in_channels, periods))
    )
    return X, edge_index


def create_mock_states(number_of_nodes, out_channels):
    """
    Creating mock hidden and cell states.
@@ -222,6 +238,41 @@ def test_a3tgcn_layer():
    assert H.shape == (number_of_nodes, out_channels)


def test_a3tgcn2_layer():
    """
    Testing the A3TGCN2 Layer by adding a batch index.
    """
    number_of_nodes = 100
    edge_per_node = 10
    in_channels = 64
    out_channels = 16
    periods = 7
    batch_size = 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X, edge_index = create_mock_attention_batch_data(
        number_of_nodes, edge_per_node, in_channels, periods, batch_size
    )
    X = X.to(device)
    edge_index = edge_index.to(device)
    edge_weight = create_mock_edge_weight(edge_index).to(device)

    layer = A3TGCN2(
        in_channels=in_channels, out_channels=out_channels, periods=periods, batch_size=batch_size
    ).to(device)

    H = layer(X, edge_index)

    assert H.shape == (batch_size, number_of_nodes, out_channels)

    H = layer(X, edge_index, edge_weight)

    assert H.shape == (batch_size, number_of_nodes, out_channels)

    H = layer(X, edge_index, edge_weight, H)

    assert H.shape == (batch_size, number_of_nodes, out_channels)


def test_dcrnn_layer():
    """
    Testing the DCRNN Layer.
2 changes: 2 additions & 0 deletions torch_geometric_temporal/nn/recurrent/__init__.py
@@ -7,6 +7,8 @@
from .evolvegcno import EvolveGCNO
from .dcrnn import DCRNN
from .temporalgcn import TGCN
from .temporalgcn import TGCN2
from .attentiontemporalgcn import A3TGCN
from .attentiontemporalgcn import A3TGCN2
from .mpnn_lstm import MPNNLSTM
from .agcrn import AGCRN
84 changes: 82 additions & 2 deletions torch_geometric_temporal/nn/recurrent/attentiontemporalgcn.py
@@ -1,5 +1,6 @@
import torch
from .temporalgcn import TGCN
from .temporalgcn import TGCN2
from torch_geometric.nn import GCNConv


@@ -24,7 +25,7 @@ def __init__(
        periods: int,
        improved: bool = False,
        cached: bool = False,
        add_self_loops: bool = True,
        add_self_loops: bool = True
    ):
        super(A3TGCN, self).__init__()

@@ -44,7 +45,8 @@ def _setup_layers(self):
            cached=self.cached,
            add_self_loops=self.add_self_loops,
        )
        self._attention = torch.empty(self.periods)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self._attention = torch.nn.Parameter(torch.empty(self.periods, device=device))
        torch.nn.init.uniform_(self._attention)

    def forward(
@@ -75,3 +77,81 @@
                X[:, :, period], edge_index, edge_weight, H
            )
        return H_accum



class A3TGCN2(torch.nn.Module):
    r"""An implementation of the Attention Temporal Graph Convolutional Cell
    with batch support. For details see this paper: `"A3T-GCN: Attention Temporal
    Graph Convolutional Network for Traffic Forecasting." <https://arxiv.org/abs/2006.11583>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        periods (int): Number of time periods.
        batch_size (int): Size of the batch.
        improved (bool): Stronger self loops (default :obj:`False`).
        cached (bool): Caching the message weights (default :obj:`False`).
        add_self_loops (bool): Adding self-loops for smoothing (default :obj:`True`).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        periods: int,
        batch_size: int,
        improved: bool = False,
        cached: bool = False,
        add_self_loops: bool = True):
        super(A3TGCN2, self).__init__()

        self.in_channels = in_channels  # 2
        self.out_channels = out_channels  # 32
        self.periods = periods  # 12
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.batch_size = batch_size
        self._setup_layers()

    def _setup_layers(self):
        self._base_tgcn = TGCN2(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            batch_size=self.batch_size,
            improved=self.improved,
            cached=self.cached,
            add_self_loops=self.add_self_loops)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self._attention = torch.nn.Parameter(torch.empty(self.periods, device=device))
        torch.nn.init.uniform_(self._attention)

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None
    ) -> torch.FloatTensor:
        """
        Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph. If the hidden state matrix is not present
        when the forward pass is called it is initialized with zeros.

        Arg types:
            * **X** (PyTorch Float Tensor): Node features for T time periods.
            * **edge_index** (PyTorch Long Tensor): Graph edge indices.
            * **edge_weight** (PyTorch Float Tensor, optional): Edge weight vector.
            * **H** (PyTorch Float Tensor, optional): Hidden state matrix for all nodes.

        Return types:
            * **H** (PyTorch Float Tensor): Hidden state matrix for all nodes.
        """
        H_accum = 0
        probs = torch.nn.functional.softmax(self._attention, dim=0)
        for period in range(self.periods):
            H_accum = H_accum + probs[period] * self._base_tgcn(
                X[:, :, :, period], edge_index, edge_weight, H
            )  # each term: (batch_size, num_nodes, out_channels), e.g. (32, 207, 32)
        return H_accum
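
For reference, a minimal usage sketch of the new layer (not part of the diff; the toy graph is hypothetical, shapes follow the example script above):

import torch
from torch_geometric_temporal.nn.recurrent import A3TGCN2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
layer = A3TGCN2(in_channels=2, out_channels=32, periods=12, batch_size=4).to(device)
X = torch.rand(4, 207, 2, 12).to(device)                    # (batch, nodes, features, periods)
edge_index = torch.LongTensor([[0, 1], [1, 0]]).to(device)  # hypothetical two-edge toy graph
H = layer(X, edge_index)                                    # -> (4, 207, 32)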