diff --git a/examples/other/node_clustering_tvgnn.py b/examples/other/node_clustering_tvgnn.py index 06a9a81f..11ed8d1d 100644 --- a/examples/other/node_clustering_tvgnn.py +++ b/examples/other/node_clustering_tvgnn.py @@ -7,18 +7,19 @@ """ import numpy as np -from tqdm import tqdm +import tensorflow as tf from sklearn.metrics.cluster import ( completeness_score, homogeneity_score, normalized_mutual_info_score, ) -import tensorflow as tf from tensorflow.keras import Model -from spektral.utils.sparse import sp_matrix_to_sp_tensor -from spektral.datasets.citation import Citation +from tqdm import tqdm + from spektral.datasets import DBLP -from spektral.layers import GTVConv, AsymCheegerCutPool +from spektral.datasets.citation import Citation +from spektral.layers import AsymCheegerCutPool, GTVConv +from spektral.utils.sparse import sp_matrix_to_sp_tensor tf.random.set_seed(1) @@ -30,12 +31,12 @@ mp_layers = 2 mp_activation = "elu" delta_coeff = 0.311 -epsilon=1e-3 +epsilon = 1e-3 mlp_hidden_channels = 256 mlp_hidden_layers = 1 mlp_activation = "relu" -totvar_coeff=0.785 -balance_coeff=0.514 +totvar_coeff = 0.785 +balance_coeff = 0.514 learning_rate = 1e-3 epochs = 500 @@ -77,13 +78,14 @@ def call(self, inputs): return s_pool + # Define the message-passing layers -MP_layers = [GTVConv( - mp_channels, - delta_coeff=delta_coeff, - epsilon=1e-3, - activation=mp_activation) -for _ in range(mp_layers)] +MP_layers = [ + GTVConv( + mp_channels, delta_coeff=delta_coeff, epsilon=1e-3, activation=mp_activation + ) + for _ in range(mp_layers) +] # Define the pooling layer pool_layer = AsymCheegerCutPool( @@ -92,7 +94,8 @@ def call(self, inputs): mlp_activation=mlp_activation, totvar_coeff=totvar_coeff, balance_coeff=balance_coeff, - return_selection=True) + return_selection=True, +) # Instantiate model and optimizer model = ClusteringModel(aggr=MP_layers, pool=pool_layer) @@ -110,6 +113,7 @@ def train_step(model, inputs): opt.apply_gradients(zip(gradients, model.trainable_variables)) return model.losses + A = sp_matrix_to_sp_tensor(A) inputs = [X, A] loss_history = [] @@ -126,4 +130,4 @@ def train_step(model, inputs): nmi = normalized_mutual_info_score(y, s_out) hom = homogeneity_score(y, s_out) com = completeness_score(y, s_out) -print("Homogeneity: {:.3f}; Completeness: {:.3f}; NMI: {:.3f}".format(hom, com, nmi)) \ No newline at end of file +print("Homogeneity: {:.3f}; Completeness: {:.3f}; NMI: {:.3f}".format(hom, com, nmi)) diff --git a/spektral/layers/convolutional/gtv_conv.py b/spektral/layers/convolutional/gtv_conv.py index ea469975..9b3c1a22 100644 --- a/spektral/layers/convolutional/gtv_conv.py +++ b/spektral/layers/convolutional/gtv_conv.py @@ -1,8 +1,10 @@ import tensorflow as tf from tensorflow.keras import backend as K + from spektral.layers import ops from spektral.layers.convolutional.conv import Conv + class GTVConv(Conv): r""" A graph total variation convolutional layer (GTVConv) from the paper @@ -132,43 +134,48 @@ def _call_single(self, x, a): if K.is_sparse(a): index_i = a.indices[:, 0] index_j = a.indices[:, 1] - + n_nodes = tf.shape(a, out_type=index_i.dtype)[0] - + # Compute absolute differences between neighbouring nodes - abs_diff = tf.math.abs(tf.transpose(tf.gather(x, index_i)) - - tf.transpose(tf.gather(x, index_j))) + abs_diff = tf.math.abs( + tf.transpose(tf.gather(x, index_i)) + - tf.transpose(tf.gather(x, index_j)) + ) abs_diff = tf.math.reduce_sum(abs_diff, axis=0) - + # Compute new adjacency matrix - gamma = tf.sparse.map_values(tf.multiply, - a, - 1 / tf.math.maximum(abs_diff, self.epsilon)) - + gamma = tf.sparse.map_values( + tf.multiply, a, 1 / tf.math.maximum(abs_diff, self.epsilon) + ) + # Compute degree matrix from gamma matrix - d_gamma = tf.sparse.SparseTensor(tf.stack([tf.range(n_nodes)] * 2, axis=1), - tf.sparse.reduce_sum(gamma, axis=-1), - [n_nodes, n_nodes]) - + d_gamma = tf.sparse.SparseTensor( + tf.stack([tf.range(n_nodes)] * 2, axis=1), + tf.sparse.reduce_sum(gamma, axis=-1), + [n_nodes, n_nodes], + ) + # Compute laplcian: L = D_gamma - Gamma - l = tf.sparse.add(d_gamma, tf.sparse.map_values( - tf.multiply, gamma, -1.)) - + l = tf.sparse.add(d_gamma, tf.sparse.map_values(tf.multiply, gamma, -1.0)) + # Compute adjusted laplacian: L_adjusted = I - delta*L - l = tf.sparse.add(tf.sparse.eye(n_nodes, dtype=x.dtype), tf.sparse.map_values( - tf.multiply, l, -self.delta_coeff)) - + l = tf.sparse.add( + tf.sparse.eye(n_nodes, dtype=x.dtype), + tf.sparse.map_values(tf.multiply, l, -self.delta_coeff), + ) + # Aggregate features with adjusted laplacian output = ops.modal_dot(l, x) - + else: n_nodes = tf.shape(a)[-1] - + abs_diff = tf.math.abs(x[:, tf.newaxis, :] - x) abs_diff = tf.reduce_sum(abs_diff, axis=-1) - + gamma = a / tf.math.maximum(abs_diff, self.epsilon) - + degrees = tf.math.reduce_sum(gamma, axis=-1) l = -gamma l = tf.linalg.set_diag(l, degrees - tf.linalg.diag_part(gamma)) @@ -180,9 +187,10 @@ def _call_single(self, x, a): def _call_batch(self, x, a): n_nodes = tf.shape(a)[-1] - - abs_diff = tf.reduce_sum(tf.math.abs(tf.expand_dims(x, 2) - - tf.expand_dims(x, 1)), axis = -1) + + abs_diff = tf.reduce_sum( + tf.math.abs(tf.expand_dims(x, 2) - tf.expand_dims(x, 1)), axis=-1 + ) gamma = a / tf.math.maximum(abs_diff, self.epsilon) @@ -192,11 +200,13 @@ def _call_batch(self, x, a): l = tf.eye(n_nodes, dtype=x.dtype) - self.delta_coeff * l output = tf.matmul(l, x) - + return output - + @property def config(self): - return {"channels": self.channels, - "delta_coeff": self.delta_coeff, - "epsilon": self.epsilon} \ No newline at end of file + return { + "channels": self.channels, + "delta_coeff": self.delta_coeff, + "epsilon": self.epsilon, + } diff --git a/spektral/layers/pooling/asym_cheeger_cut_pool.py b/spektral/layers/pooling/asym_cheeger_cut_pool.py index 5f7a8599..fe59e9af 100644 --- a/spektral/layers/pooling/asym_cheeger_cut_pool.py +++ b/spektral/layers/pooling/asym_cheeger_cut_pool.py @@ -1,10 +1,12 @@ import tensorflow as tf +import tensorflow.keras.backend as K from tensorflow.keras import Sequential from tensorflow.keras.layers import Dense -import tensorflow.keras.backend as K + from spektral.layers import ops from spektral.layers.pooling.src import SRCPool + class AsymCheegerCutPool(SRCPool): r""" An Asymmetric Cheeger Cut Pooling layer from the paper @@ -151,7 +153,7 @@ def reduce(self, x, s, **kwargs): def connect(self, a, s, **kwargs): a_pool = ops.matmul_at_b_a(s, a) - + return a_pool def reduce_index(self, i, s, **kwargs): @@ -159,7 +161,7 @@ def reduce_index(self, i, s, **kwargs): i_pool = ops.repeat(i_mean, tf.ones_like(i_mean) * self.k) return i_pool - + def totvar_loss(self, a, s): if K.is_sparse(a): index_i = a.indices[:, 0] @@ -167,25 +169,38 @@ def totvar_loss(self, a, s): n_edges = tf.cast(len(a.values), dtype=s.dtype) - loss = tf.math.reduce_sum(a.values[:, tf.newaxis] * - tf.math.abs(tf.gather(s, index_i) - - tf.gather(s, index_j)), - axis=(-2, -1)) + loss = tf.math.reduce_sum( + a.values[:, tf.newaxis] + * tf.math.abs(tf.gather(s, index_i) - tf.gather(s, index_j)), + axis=(-2, -1), + ) else: - n_edges = tf.cast(tf.math.count_nonzero( - a, axis=(-2, -1)), dtype=s.dtype) + n_edges = tf.cast(tf.math.count_nonzero(a, axis=(-2, -1)), dtype=s.dtype) n_nodes = tf.shape(a)[-1] if K.ndim(a) == 3: - loss = tf.math.reduce_sum(a * tf.math.reduce_sum(tf.math.abs(s[:, tf.newaxis, ...] - - tf.repeat(s[..., tf.newaxis, :], - n_nodes, axis=-2)), axis=-1), - axis=(-2, -1)) + loss = tf.math.reduce_sum( + a + * tf.math.reduce_sum( + tf.math.abs( + s[:, tf.newaxis, ...] + - tf.repeat(s[..., tf.newaxis, :], n_nodes, axis=-2) + ), + axis=-1, + ), + axis=(-2, -1), + ) else: - loss = tf.math.reduce_sum(a * tf.math.reduce_sum(tf.math.abs(s - - tf.repeat(s[..., tf.newaxis, :], - n_nodes, axis=-2)), axis=-1), - axis=(-2, -1)) + loss = tf.math.reduce_sum( + a + * tf.math.reduce_sum( + tf.math.abs( + s - tf.repeat(s[..., tf.newaxis, :], n_nodes, axis=-2) + ), + axis=-1, + ), + axis=(-2, -1), + ) loss *= 1 / (2 * n_edges) @@ -196,15 +211,15 @@ def balance_loss(self, s): # k-quantile idx = tf.cast(tf.math.floor(n_nodes / self.k) + 1, dtype=tf.int32) - med = tf.math.top_k(tf.linalg.matrix_transpose(s), - k=idx).values[..., -1] + med = tf.math.top_k(tf.linalg.matrix_transpose(s), k=idx).values[..., -1] # Asymmetric l1-norm if K.ndim(s) == 2: loss = s - med else: loss = s - med[:, tf.newaxis, ...] - loss = ((tf.cast(loss >= 0, loss.dtype) * (self.k - 1) * loss) + - (tf.cast(loss < 0, loss.dtype) * loss * -1.)) + loss = (tf.cast(loss >= 0, loss.dtype) * (self.k - 1) * loss) + ( + tf.cast(loss < 0, loss.dtype) * loss * -1.0 + ) loss = tf.math.reduce_sum(loss, axis=(-2, -1)) loss = 1 / (n_nodes * (self.k - 1)) * (n_nodes * (self.k - 1) - loss) @@ -216,7 +231,7 @@ def get_config(self): "mlp_hidden": self.mlp_hidden, "mlp_activation": self.mlp_activation, "totvar_coeff": self.totvar_coeff, - "balance_coeff": self.balance_coeff + "balance_coeff": self.balance_coeff, } base_config = super().get_config() - return {**base_config, **config} \ No newline at end of file + return {**base_config, **config} diff --git a/tests/test_layers/convolutional/test_gtv_conv.py b/tests/test_layers/convolutional/test_gtv_conv.py index cbf60398..ef498348 100644 --- a/tests/test_layers/convolutional/test_gtv_conv.py +++ b/tests/test_layers/convolutional/test_gtv_conv.py @@ -5,7 +5,12 @@ config = { "layer": layers.GTVConv, "modes": [MODES["SINGLE"], MODES["BATCH"]], - "kwargs": {"channels": 8, "delta_coeff": 1.0, "epsilon": 0.001, "activation": "relu"}, + "kwargs": { + "channels": 8, + "delta_coeff": 1.0, + "epsilon": 0.001, + "activation": "relu", + }, "dense": True, "sparse": True, "edges": False, @@ -13,4 +18,4 @@ def test_layer(): - run_layer(config) \ No newline at end of file + run_layer(config) diff --git a/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py b/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py index 468c5f04..3baf3be5 100644 --- a/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py +++ b/tests/test_layers/pooling/test_asym_cheeger_cut_pool.py @@ -4,7 +4,13 @@ config = { "layer": layers.AsymCheegerCutPool, "modes": [MODES["SINGLE"], MODES["BATCH"]], - "kwargs": {"k": 5, "return_selection": True, "mlp_hidden": [32], "totvar_coeff": 1.0, "balance_coeff": 1.0}, + "kwargs": { + "k": 5, + "return_selection": True, + "mlp_hidden": [32], + "totvar_coeff": 1.0, + "balance_coeff": 1.0, + }, "dense": True, "sparse": True, }