|
6 | 6 | import multiprocessing |
7 | 7 | import math |
8 | 8 | import pandas as pd |
| 9 | +import numpy as np |
| 10 | +np.set_printoptions(threshold=np.inf) |
9 | 11 |
|
10 | 12 | from crystalformer.src.utils import GLXYZAW_from_file, letter_to_number |
11 | 13 | from crystalformer.src.elements import element_dict, element_list |
|
72 | 74 | group.add_argument('--T1', type=float, default=None, help='temperature used for sampling the first atom type') |
73 | 75 | group.add_argument('--num_io_process', type=int, default=40, help='number of process used in multiprocessing io') |
74 | 76 | group.add_argument('--num_samples', type=int, default=1000, help='number of test samples') |
| 77 | +group.add_argument('--save_path', type=str, default=None, help='path to save the sampled structures') |
75 | 78 | group.add_argument('--output_filename', type=str, default='output.csv', help='outfile to save sampled structures') |
76 | 79 |
|
77 | 80 | group = parser.add_argument_group('MCMC parameters') |
|
96 | 99 | valid_data = GLXYZAW_from_file(args.valid_path, args.atom_types, args.wyck_types, args.n_max, args.num_io_process) |
97 | 100 | else: |
98 | 101 | assert (args.spacegroup is not None) # for inference we need to specify space group |
99 | | - test_data = GLXYZAW_from_file(args.test_path, args.atom_types, args.wyck_types, args.n_max, args.num_io_process) |
| 102 | + # test_data = GLXYZAW_from_file(args.test_path, args.atom_types, args.wyck_types, args.n_max, args.num_io_process) |
100 | 103 |
|
101 | 104 | # jnp.set_printoptions(threshold=jnp.inf) # print full array |
102 | 105 | constraints = jnp.arange(0, args.n_max, 1) |
|
189 | 192 | os.makedirs(output_path, exist_ok=True) |
190 | 193 | print("Create directory for output: %s" % output_path) |
191 | 194 | else: |
192 | | - output_path = os.path.dirname(args.restore_path) |
| 195 | + output_path = os.path.dirname(args.save_path) if args.save_path else os.path.dirname(args.restore_path) |
193 | 196 | print("Will output samples to: %s" % output_path) |
194 | 197 |
|
195 | 198 |
|
|
227 | 230 | params, opt_state = train(key, optimizer, opt_state, loss_fn, params, epoch_finished, args.epochs, args.batchsize, train_data, valid_data, output_path, args.val_interval) |
228 | 231 |
|
229 | 232 | else: |
230 | | - pass |
231 | | - |
232 | | - print("\n========== Calculate the loss of test dataset ==========") |
233 | | - import numpy as np |
234 | | - np.set_printoptions(threshold=np.inf) |
235 | | - |
236 | | - test_G, test_L, test_XYZ, test_A, test_W = test_data |
237 | | - print (test_G.shape, test_L.shape, test_XYZ.shape, test_A.shape, test_W.shape) |
238 | | - test_loss = 0 |
239 | | - num_samples = len(test_L) |
240 | | - num_batches = math.ceil(num_samples / args.batchsize) |
241 | | - for batch_idx in range(num_batches): |
242 | | - start_idx = batch_idx * args.batchsize |
243 | | - end_idx = min(start_idx + args.batchsize, num_samples) |
244 | | - G, L, XYZ, A, W = test_G[start_idx:end_idx], \ |
245 | | - test_L[start_idx:end_idx], \ |
246 | | - test_XYZ[start_idx:end_idx], \ |
247 | | - test_A[start_idx:end_idx], \ |
248 | | - test_W[start_idx:end_idx] |
249 | | - loss, _ = jax.jit(loss_fn, static_argnums=7)(params, key, G, L, XYZ, A, W, False) |
250 | | - test_loss += loss |
251 | | - test_loss = test_loss / num_batches |
252 | | - print ("evaluating loss on test data:" , test_loss) |
253 | 233 |
|
254 | 234 | print("\n========== Start sampling ==========") |
255 | 235 | jax.config.update("jax_enable_x64", True) # to get off compilation warning, and to prevent sample nan lattice |
|
0 commit comments