diff --git a/GCN/GCN_jax.py b/GCN/GCN_jax.py new file mode 100644 index 0000000..28dc1f4 --- /dev/null +++ b/GCN/GCN_jax.py @@ -0,0 +1,123 @@ +import jax +import optax +import jax.numpy as jnp +from jax.scipy.linalg import inv +import matplotlib.pyplot as plt + +from celluloid import Camera + +plt.rcParams['animation.ffmpeg_path'] = '/usr/local/bin/ffmpeg' + + +A=jnp.array([[0,1,1,1,1,1,1,1,1,0,1,1,1,1,0,0,0,1,0,1,0,1,0,0,0,0,0,0,0,0,0,1,0,0], + [1,0,1,1,0,0,0,1,0,0,0,0,0,1,0,0,0,1,0,1,0,1,0,0,0,0,0,0,0,0,1,0,0,0], + [1,1,0,1,0,0,0,1,1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,1,0], + [1,1,1,0,0,0,0,1,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], + [1,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], + [1,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], + [1,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], + [1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], + [1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,1], + [0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1], + [1,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], + [1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], + [1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], + [1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1], + [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1], + [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1], + [0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], + [1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], + [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1], + [1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1], + [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1], + [1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0], + [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1], + [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,1,0,0,1,1], + [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,1,0,0], + [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,1,0,0], + [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1], + [0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,1], + [0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1], + [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,1,1], + [0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1], + [1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,1,0,0,0,1,1], + [0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,1,0,1,0,1,1,0,0,0,0,0,1,1,1,0,1], + [0,0,0,0,0,0,0,0,1,1,0,0,0,1,1,1,0,0,1,1,1,0,1,1,0,0,1,1,1,1,1,1,1,0] + ],dtype = jnp.float32) + +target=jnp.array([0,-1,-1,-1, -1, -1, -1, -1,-1,-1,-1,-1, -1, -1, -1, -1,-1,-1,-1,-1, -1, -1, -1, -1,-1,-1,-1,-1, -1, -1, -1, -1,-1,1],dtype = jnp.float32) + +X=jnp.eye(A.shape[0],dtype = jnp.float32) + +def GCNConv(A,X,W): + A_hat = A + jnp.eye(A.shape[0]) + D = jnp.diag(jnp.sum(A,1)) + D = jnp.sqrt(inv(D)) + A_hat = jnp.matmul(jnp.matmul(D,A_hat),D) + A_hat = jnp.matmul(jnp.matmul(A_hat,X),W) + return jax.nn.relu(A_hat) + +seed = 0 + +def init_GCNConv(layer_widths,key): + params = [] + first_key = key + for in_width, out_width in zip(layer_widths[:-1],layer_widths[1:]): + first_key, second_key = jax.random.split(first_key) + params.append( + jax.random.uniform(second_key, shape = (in_width,out_width)) + ) + return params + +rng = jax.random.PRNGKey(seed) + +GCN_params = init_GCNConv([X.shape[0],10,2],rng) + +def GCN_predict(params,x,a): + out = x + for w in params: + out = GCNConv(a,out,w) + return out + +num_epochs = 200 +lr = 0.01 +fig = plt.figure() +camera = Camera(fig) + + + +def loss(params,x,a,l): + predictions = GCN_predict(params,x,a) + gt_labels = jax.nn.one_hot(l[l != -1], 2) + loss = optax.softmax_cross_entropy(predictions[l != -1],gt_labels) + return loss.sum() + +def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params: + opt_state = optimizer.init(params) + + # @jax.jit + def step(params, opt_state, x,a,l): + loss_value, grads = jax.value_and_grad(loss)(params, x,a,l) + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + return params, opt_state, loss_value + + for epoch in range(num_epochs): + params, opt_state, _ = step(params, opt_state, X,A, target) + l=(GCN_predict(params,X,A)); + + plt.scatter(l[:,0],l[:,1],c=[0, 0, 0, 0 ,0 ,0 ,0, 0, 1, 1, 0 ,0, 0, 0, 1 ,1 ,0 ,0 ,1, 0, 1, 0 ,1 ,1, 1, 1, 1 ,1 ,1, 1, 1, 1, 1, 1 ]) + for i in range(l.shape[0]): + text_plot = plt.text(l[i,0], l[i,1], str(i+1)) + + camera.snap() + print(epoch) + + + return params + +optimizer = optax.sgd(learning_rate=1e-2,momentum = 0.9) +params = fit(GCN_params, optimizer) + +animation = camera.animate(blit=False, interval=150) +animation.save('./img/train_karate_animation_jax.gif', writer='ffmpeg', fps=60) \ No newline at end of file diff --git a/GCN/img/train_karate_animation_jax.gif b/GCN/img/train_karate_animation_jax.gif new file mode 100644 index 0000000..500741d Binary files /dev/null and b/GCN/img/train_karate_animation_jax.gif differ