|
229 | 229 | else: |
230 | 230 | pass |
231 | 231 |
|
232 | | - print("\n========== Print out some test data for the given space group ==========") |
| 232 | + print("\n========== Calculate the loss of test dataset ==========") |
233 | 233 | import numpy as np |
234 | 234 | np.set_printoptions(threshold=np.inf) |
235 | 235 |
|
236 | | - G, L, XYZ, A, W = test_data |
237 | | - print (G.shape, L.shape, XYZ.shape, A.shape, W.shape) |
238 | | - |
239 | | - idx = jnp.where(G==args.spacegroup,size=5) |
240 | | - G = G[idx] |
241 | | - L = L[idx] |
242 | | - XYZ = XYZ[idx] |
243 | | - A = A[idx] |
244 | | - W = W[idx] |
245 | | - |
246 | | - num_sites = jnp.sum(A!=0, axis=1) |
247 | | - print ("num_sites:", num_sites) |
248 | | - @jax.vmap |
249 | | - def lookup(G, W): |
250 | | - return mult_table[G-1, W] # (n_max, ) |
251 | | - M = lookup(G, W) # (batchsize, n_max) |
252 | | - num_atoms = M.sum(axis=-1) |
253 | | - print ("num_atoms:", num_atoms) |
254 | | - |
255 | | - print ("G:", G) |
256 | | - print ("A:\n", A) |
257 | | - for a in A: |
258 | | - print([element_list[i] for i in a]) |
259 | | - print ("W:\n",W) |
260 | | - print ("XYZ:\n",XYZ) |
261 | | - |
262 | | - outputs = jax.vmap(transformer, (None, None, 0, 0, 0, 0, 0, None), (0))(params, key, G, XYZ, A, W, M, False) |
263 | | - print ("outputs.shape", outputs.shape) |
264 | | - |
265 | | - h_al = outputs[:, 1::5, :] # (:, n_max, :) |
266 | | - a_logit = h_al[:, :, :args.atom_types] |
267 | | - l_logit, mu, sigma = jnp.split(h_al[jnp.arange(h_al.shape[0]), num_sites, |
268 | | - args.atom_types:args.atom_types+args.Kl+2*6*args.Kl], |
269 | | - [args.Kl, args.Kl+6*args.Kl], axis=-1) |
270 | | - print ("L:\n",L) |
271 | | - print ("exp(l_logit):\n", jnp.exp(l_logit)) |
272 | | - print ("mu:\n", mu.reshape(-1, args.Kl, 6)) |
273 | | - print ("sigma:\n", sigma.reshape(-1, args.Kl, 6)) |
| 236 | + test_G, test_L, test_XYZ, test_A, test_W = test_data |
| 237 | + print (test_G.shape, test_L.shape, test_XYZ.shape, test_A.shape, test_W.shape) |
| 238 | + test_loss = 0 |
| 239 | + num_samples = len(test_L) |
| 240 | + num_batches = math.ceil(num_samples / args.batchsize) |
| 241 | + for batch_idx in range(num_batches): |
| 242 | + start_idx = batch_idx * args.batchsize |
| 243 | + end_idx = min(start_idx + args.batchsize, num_samples) |
| 244 | + G, L, XYZ, A, W = test_G[start_idx:end_idx], \ |
| 245 | + test_L[start_idx:end_idx], \ |
| 246 | + test_XYZ[start_idx:end_idx], \ |
| 247 | + test_A[start_idx:end_idx], \ |
| 248 | + test_W[start_idx:end_idx] |
| 249 | + loss, _ = jax.jit(loss_fn, static_argnums=7)(params, key, G, L, XYZ, A, W, False) |
| 250 | + test_loss += loss |
| 251 | + test_loss = test_loss / num_batches |
| 252 | + print ("evaluating loss on test data:" , test_loss) |
274 | 253 |
|
275 | 254 | print("\n========== Start sampling ==========") |
276 | 255 | jax.config.update("jax_enable_x64", True) # to get off compilation warning, and to prevent sample nan lattice |
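
Note on the test-loss loop added above: it calls jax.jit(loss_fn, static_argnums=7) inside the loop, re-wrapping loss_fn on every batch. Below is a minimal sketch of the same evaluation with the jitted function hoisted out of the loop and the running loss weighted by batch size, so a short final batch does not skew the average. It assumes the loss_fn signature (params, key, G, L, XYZ, A, W, is_train) and the test_* arrays exactly as they appear in the diff; the per-sample weighting is an assumption, not part of the commit.

    import math
    import jax

    # Wrap loss_fn once; the compiled version is then reused for every batch
    # (one extra compilation may still happen for a smaller final batch shape).
    jitted_loss_fn = jax.jit(loss_fn, static_argnums=7)

    num_samples = len(test_L)
    num_batches = math.ceil(num_samples / args.batchsize)
    total_loss = 0.0
    for batch_idx in range(num_batches):
        start = batch_idx * args.batchsize
        end = min(start + args.batchsize, num_samples)
        G, L, XYZ, A, W = (x[start:end] for x in (test_G, test_L, test_XYZ, test_A, test_W))
        loss, _ = jitted_loss_fn(params, key, G, L, XYZ, A, W, False)
        # Weight by the actual batch size so the final average is exact
        # even when the last batch is smaller than args.batchsize.
        total_loss += loss * (end - start)
    test_loss = total_loss / num_samples
    print("evaluating loss on test data:", test_loss)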
|