Skip to content

Commit 6dcd859

Browse files
committed
fix CI bugs
1 parent abfeebe commit 6dcd859

File tree

3 files changed

+18
-99
lines changed

3 files changed

+18
-99
lines changed

main.py

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -229,48 +229,27 @@
229229
else:
230230
pass
231231

232-
print("\n========== Print out some test data for the given space group ==========")
232+
print("\n========== Calculate the loss of test dataset ==========")
233233
import numpy as np
234234
np.set_printoptions(threshold=np.inf)
235235

236-
G, L, XYZ, A, W = test_data
237-
print (G.shape, L.shape, XYZ.shape, A.shape, W.shape)
238-
239-
idx = jnp.where(G==args.spacegroup,size=5)
240-
G = G[idx]
241-
L = L[idx]
242-
XYZ = XYZ[idx]
243-
A = A[idx]
244-
W = W[idx]
245-
246-
num_sites = jnp.sum(A!=0, axis=1)
247-
print ("num_sites:", num_sites)
248-
@jax.vmap
249-
def lookup(G, W):
250-
return mult_table[G-1, W] # (n_max, )
251-
M = lookup(G, W) # (batchsize, n_max)
252-
num_atoms = M.sum(axis=-1)
253-
print ("num_atoms:", num_atoms)
254-
255-
print ("G:", G)
256-
print ("A:\n", A)
257-
for a in A:
258-
print([element_list[i] for i in a])
259-
print ("W:\n",W)
260-
print ("XYZ:\n",XYZ)
261-
262-
outputs = jax.vmap(transformer, (None, None, 0, 0, 0, 0, 0, None), (0))(params, key, G, XYZ, A, W, M, False)
263-
print ("outputs.shape", outputs.shape)
264-
265-
h_al = outputs[:, 1::5, :] # (:, n_max, :)
266-
a_logit = h_al[:, :, :args.atom_types]
267-
l_logit, mu, sigma = jnp.split(h_al[jnp.arange(h_al.shape[0]), num_sites,
268-
args.atom_types:args.atom_types+args.Kl+2*6*args.Kl],
269-
[args.Kl, args.Kl+6*args.Kl], axis=-1)
270-
print ("L:\n",L)
271-
print ("exp(l_logit):\n", jnp.exp(l_logit))
272-
print ("mu:\n", mu.reshape(-1, args.Kl, 6))
273-
print ("sigma:\n", sigma.reshape(-1, args.Kl, 6))
236+
test_G, test_L, test_XYZ, test_A, test_W = test_data
237+
print (test_G.shape, test_L.shape, test_XYZ.shape, test_A.shape, test_W.shape)
238+
test_loss = 0
239+
num_samples = len(test_L)
240+
num_batches = math.ceil(num_samples / args.batchsize)
241+
for batch_idx in range(num_batches):
242+
start_idx = batch_idx * args.batchsize
243+
end_idx = min(start_idx + args.batchsize, num_samples)
244+
G, L, XYZ, A, W = test_G[start_idx:end_idx], \
245+
test_L[start_idx:end_idx], \
246+
test_XYZ[start_idx:end_idx], \
247+
test_A[start_idx:end_idx], \
248+
test_W[start_idx:end_idx]
249+
loss, _ = jax.jit(loss_fn, static_argnums=7)(params, key, G, L, XYZ, A, W, False)
250+
test_loss += loss
251+
test_loss = test_loss / num_batches
252+
print ("evaluating loss on test data:" , test_loss)
274253

275254
print("\n========== Start sampling ==========")
276255
jax.config.update("jax_enable_x64", True) # to get off compilation warning, and to prevent sample nan lattice

tests/test_project.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

tests/test_util.py

Lines changed: 0 additions & 44 deletions
This file was deleted.

0 commit comments

Comments
 (0)