Skip to content

Commit 0324bbf

Browse files
committed
rm testdata and add save option
1 parent 57898ac commit 0324bbf

File tree

1 file changed

+5
-25
lines changed

1 file changed

+5
-25
lines changed

main.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import multiprocessing
77
import math
88
import pandas as pd
9+
import numpy as np
10+
np.set_printoptions(threshold=np.inf)
911

1012
from crystalformer.src.utils import GLXYZAW_from_file, letter_to_number
1113
from crystalformer.src.elements import element_dict, element_list
@@ -72,6 +74,7 @@
7274
group.add_argument('--T1', type=float, default=None, help='temperature used for sampling the first atom type')
7375
group.add_argument('--num_io_process', type=int, default=40, help='number of process used in multiprocessing io')
7476
group.add_argument('--num_samples', type=int, default=1000, help='number of test samples')
77+
group.add_argument('--save_path', type=str, default=None, help='path to save the sampled structures')
7578
group.add_argument('--output_filename', type=str, default='output.csv', help='outfile to save sampled structures')
7679

7780
group = parser.add_argument_group('MCMC parameters')
@@ -96,7 +99,7 @@
9699
valid_data = GLXYZAW_from_file(args.valid_path, args.atom_types, args.wyck_types, args.n_max, args.num_io_process)
97100
else:
98101
assert (args.spacegroup is not None) # for inference we need to specify space group
99-
test_data = GLXYZAW_from_file(args.test_path, args.atom_types, args.wyck_types, args.n_max, args.num_io_process)
102+
# test_data = GLXYZAW_from_file(args.test_path, args.atom_types, args.wyck_types, args.n_max, args.num_io_process)
100103

101104
# jnp.set_printoptions(threshold=jnp.inf) # print full array
102105
constraints = jnp.arange(0, args.n_max, 1)
@@ -189,7 +192,7 @@
189192
os.makedirs(output_path, exist_ok=True)
190193
print("Create directory for output: %s" % output_path)
191194
else:
192-
output_path = os.path.dirname(args.restore_path)
195+
output_path = os.path.dirname(args.save_path) if args.save_path else os.path.dirname(args.restore_path)
193196
print("Will output samples to: %s" % output_path)
194197

195198

@@ -227,29 +230,6 @@
227230
params, opt_state = train(key, optimizer, opt_state, loss_fn, params, epoch_finished, args.epochs, args.batchsize, train_data, valid_data, output_path, args.val_interval)
228231

229232
else:
230-
pass
231-
232-
print("\n========== Calculate the loss of test dataset ==========")
233-
import numpy as np
234-
np.set_printoptions(threshold=np.inf)
235-
236-
test_G, test_L, test_XYZ, test_A, test_W = test_data
237-
print (test_G.shape, test_L.shape, test_XYZ.shape, test_A.shape, test_W.shape)
238-
test_loss = 0
239-
num_samples = len(test_L)
240-
num_batches = math.ceil(num_samples / args.batchsize)
241-
for batch_idx in range(num_batches):
242-
start_idx = batch_idx * args.batchsize
243-
end_idx = min(start_idx + args.batchsize, num_samples)
244-
G, L, XYZ, A, W = test_G[start_idx:end_idx], \
245-
test_L[start_idx:end_idx], \
246-
test_XYZ[start_idx:end_idx], \
247-
test_A[start_idx:end_idx], \
248-
test_W[start_idx:end_idx]
249-
loss, _ = jax.jit(loss_fn, static_argnums=7)(params, key, G, L, XYZ, A, W, False)
250-
test_loss += loss
251-
test_loss = test_loss / num_batches
252-
print ("evaluating loss on test data:" , test_loss)
253233

254234
print("\n========== Start sampling ==========")
255235
jax.config.update("jax_enable_x64", True) # to get off compilation warning, and to prevent sample nan lattice

0 commit comments

Comments
 (0)