
Commit abfeebe

update to v0.4.2
1 parent d4ba484 commit abfeebe


61 files changed (+3041, -884 lines)

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -2,3 +2,7 @@
 job*
 *.out
 __pycache__/
+data/
+experimental/
+*.ipynb
+*.egg-info/

README.md

Lines changed: 73 additions & 2 deletions
@@ -17,6 +17,7 @@ crystal space, which is crucial for data and compute efficient generative modeli
 
 - [Contents](#contents)
 - [Model card](#model-card)
+- [Status](#status)
 - [Get Started](#get-started)
 - [Installation](#installation)
 - [CPU installation](#cpu-installation)
@@ -27,6 +28,9 @@ crystal space, which is crucial for data and compute efficient generative modeli
 - [train](#train)
 - [sample](#sample)
 - [evaluate](#evaluate)
+- [Reinforcement Fine-tuning](#reinforcement-fine-tuning)
+- [$E\_{hull}$ Reward](#e_hull-reward)
+- [Dielectric FoM Reward](#dielectric-fom-reward)
 - [How to cite](#how-to-cite)
 
 ## Model card
@@ -44,6 +48,16 @@ The model is an autoregressive transformer for the space group conditioned cryst
 
 We only consider symmetry inequivalent atoms. The remaining atoms are restored based on the space group and Wyckoff letter information. Note that there is a natural alphabetical ordering for the Wyckoff letters, starting with 'a' for a position with the site-symmetry group of maximal order and ending with the highest letter for the general position. The sampling procedure starts from higher symmetry sites (with smaller multiplicities) and then goes on to lower symmetry ones (with larger multiplicities). Only for the cases where discrete Wyckoff letters cannot fully determine the structure, one needs to further consider fractional coordinates in the loss or sampling.
 
+## Status
+
+Major milestones are summarized below.
+- v0.4.2: Add an implementation of direct preference optimization.
+- v0.4.1: Replace the absolute positional embedding with the Rotary Positional Embedding (RoPE).
+- v0.4: Add reinforcement learning (proximal policy optimization).
+- v0.3: Add conditional generation in a plug-and-play manner.
+- v0.2: Add Markov chain Monte Carlo (MCMC) sampling for template-based structure generation.
+- v0.1: Initial implementation of crystalline material generation conditioned on the space group.
+
 ## Get Started
 
 **Notebooks**: The quickest way to get started with _CrystalFormer_ is our notebooks on the Google Colab and Bohrium (Chinese version) platforms:
@@ -88,7 +102,7 @@ pip install -r requirements.txt
 
 ## Available Weights
 
-We release the weights of the model trained on the MP-20 dataset. More details can be seen in the [model](./model/README.md) folder.
+We release the weights of the model trained on the MP-20 and Alex-20 datasets. More details can be seen in the [model](./model/README.md) folder.
 
 ## How to run
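The model-card paragraph above describes sampling Wyckoff letters in alphabetical order, from high-symmetry, low-multiplicity sites toward the general position. A purely illustrative sketch (not part of this commit), using the standard multiplicities of the first few Wyckoff positions of space group Fm-3m (No. 225):

```python
# Illustrative only, not code from this commit: alphabetical Wyckoff ordering
# for Fm-3m (No. 225), whose first few positions are 4a, 4b, 8c, 24d.
wyckoff_sites = {"d": 24, "c": 8, "b": 4, "a": 4}   # letter -> multiplicity
sampling_order = sorted(wyckoff_sites.items())       # 'a' (highest site symmetry) first
print(sampling_order)  # [('a', 4), ('b', 4), ('c', 8), ('d', 24)]
```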

@@ -163,10 +177,55 @@ Note that the training, test, and generated datasets should contain the structur
 
 More details about the post-processing can be seen in the [scripts](./scripts/README.md) folder.
 
+## Reinforcement Fine-tuning
+
+### $E_{hull}$ Reward
+
+```bash
+train_ppo --folder ./data/ \
+    --restore_path YOUR_PATH \
+    --valid_path YOUR_PATH/alex_20/val.csv \
+    --test_path YOUR_PATH/alex_20/train.csv \
+    --reward ehull \
+    --convex_path YOUR_PATH/convex_hull_pbe_2023.12.29.json.bz2 \
+    --mlff_model orb \
+    --mlff_path YOUR_PATH/orb-v2-20241011.ckpt
+```
+
+- `folder`: the folder in which to save the model and logs
+- `restore_path`: the path to the pre-trained model weights
+- `valid_path`: the path to the validation dataset
+- `test_path`: the path to the test dataset. The space group distribution is loaded from this dataset and used for sampling during reinforcement fine-tuning
+- `reward`: the reward function to use; `ehull` means the energy above the convex hull
+- `convex_path`: the path to the convex hull data used to calculate $E_{hull}$; only used when the reward is `ehull`
+- `mlff_model`: the machine learning force field model used to predict the total energy. We support [`orb`](https://github.com/orbital-materials/orb-models) and [`MACE`](https://github.com/ACEsuit/mace) models for the $E_{hull}$ reward
+- `mlff_path`: the path to the checkpoint of the machine learning force field model
+
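The `ehull` reward above combines an MLFF total energy with a precomputed convex hull. Below is a hedged sketch of the underlying quantity; the helper is hypothetical and is not the reward implementation inside `train_ppo`, which may add relaxation or reward shaping.

```python
# Hypothetical sketch, not code from this commit.
def ehull_reward(e_total: float, n_atoms: int, e_hull_per_atom: float) -> float:
    """Reward that grows as a structure approaches (or drops below) the convex hull.

    e_total         -- MLFF-predicted total energy of the structure (eV)
    n_atoms         -- number of atoms in the cell
    e_hull_per_atom -- convex-hull reference energy at this composition (eV/atom)
    """
    e_above_hull = e_total / n_atoms - e_hull_per_atom  # eV/atom; >= 0 on or above the hull
    return -e_above_hull  # maximizing the reward pushes samples toward the hull
```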
+### Dielectric FoM Reward
+
+```bash
+train_ppo --folder ./data/ \
+    --restore_path YOUR_PATH \
+    --valid_path YOUR_PATH/alex_20/val.csv \
+    --test_path YOUR_PATH/alex_20/train.csv \
+    --reward dielectric \
+    --mlff_model matgl \
+    --mlff_path YOUR_PATH/model1,YOUR_PATH/model2
+```
+
+- `folder`: the folder in which to save the model and logs
+- `restore_path`: the path to the pre-trained model weights
+- `valid_path`: the path to the validation dataset
+- `test_path`: the path to the test dataset. The space group distribution is loaded from this dataset and used for sampling during reinforcement fine-tuning
+- `reward`: the reward function to use; `dielectric` means the dielectric figure of merit (FoM), which is the product of the total dielectric constant and the band gap
+- `mlff_model`: the machine learning model used to predict the properties; only models from [`matgl`](https://github.com/materialsvirtuallab/matgl) are supported for the dielectric reward
+- `mlff_path`: the paths to the model checkpoints. Note that you need to provide the models for the total dielectric constant and the band gap, separated by a comma
+
+
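For the `dielectric` reward above, the figure of merit is the product of the predicted total dielectric constant and band gap. A hedged sketch with placeholder predictors standing in for the two matgl checkpoints passed via `--mlff_path` (not the code added by this commit):

```python
# Hypothetical sketch, not code from this commit.
def dielectric_fom_reward(predict_eps_total, predict_band_gap, structure) -> float:
    """FoM = total dielectric constant x band gap, used as the RL reward."""
    eps_total = predict_eps_total(structure)  # dimensionless total dielectric constant
    e_gap = predict_band_gap(structure)       # band gap in eV
    return eps_total * e_gap
```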
 ## How to cite
 
 ```bibtex
-@misc{cao2024space,
+@article{cao2024space,
   title={Space Group Informed Transformer for Crystalline Materials Generation},
   author={Zhendong Cao and Xiaoshan Luo and Jian Lv and Lei Wang},
   year={2024},
@@ -176,4 +235,16 @@ More details about the post-processing can be seen in the [scripts](./scripts/RE
 }
 ```
 
+```bibtex
+@article{cao2025crystalformerrl,
+  title={CrystalFormer-RL: Reinforcement Fine-Tuning for Materials Design},
+  author={Zhendong Cao and Lei Wang},
+  year={2025},
+  eprint={2504.02367},
+  archivePrefix={arXiv},
+  primaryClass={cond-mat.mtrl-sci},
+  url={https://arxiv.org/abs/2504.02367},
+}
+```
+
 **Note**: This project is unrelated to https://github.com/omron-sinicx/crystalformer with the same name.

crystalformer/cli/__init__.py

Whitespace-only changes.

classifier.py renamed to crystalformer/cli/classifier.py

Lines changed: 43 additions & 8 deletions
@@ -20,8 +20,36 @@ def get_labels(csv_file, label_col):
     labels = jnp.array(labels, dtype=float)
     return labels
 
+def GLXYZAW_from_sample(spg, test_path):
+    ### read from generated data
+    from ast import literal_eval
+    from crystalformer.src.wyckoff import mult_table
 
-if __name__ == "__main__":
+    test_data = pd.read_csv(test_path)
+    L, XYZ, A, W = test_data['L'], test_data['X'], test_data['A'], test_data['W']
+    L = L.apply(lambda x: literal_eval(x))
+    XYZ = XYZ.apply(lambda x: literal_eval(x))
+    A = A.apply(lambda x: literal_eval(x))
+    W = W.apply(lambda x: literal_eval(x))
+
+    # convert array of list to numpy ndarray
+    G = jnp.array([spg]*len(L))
+    L = jnp.array(L.tolist())
+    XYZ = jnp.array(XYZ.tolist())
+    A = jnp.array(A.tolist())
+    W = jnp.array(W.tolist())
+
+    M = jax.vmap(lambda g, w: mult_table[g-1, w], in_axes=(0, 0))(G, W) # (batchsize, n_max)
+    num_atoms = jnp.sum(M, axis=1)
+    length, angle = jnp.split(L, 2, axis=-1)
+    length = length/num_atoms[:, None]**(1/3)
+    angle = angle * (jnp.pi / 180) # to rad
+    L = jnp.concatenate([length, angle], axis=-1)
+
+    return G, L, XYZ, A, W
+
+
+def main():
 
     import argparse
     parser = argparse.ArgumentParser(description='')
@@ -30,6 +58,7 @@ def get_labels(csv_file, label_col):
     group.add_argument('--train_path', default='/data/zdcao/crystal_gpt/dataset/mp_20/train.csv', help='')
     group.add_argument('--valid_path', default='/data/zdcao/crystal_gpt/dataset/mp_20/val.csv', help='')
     group.add_argument('--test_path', default='/data/zdcao/crystal_gpt/dataset/mp_20/test.csv', help='')
+    group.add_argument('--spacegroup', type=int, default=None, help='The space group number')
     group.add_argument('--property', default='band_gap', help='The property to predict')
     group.add_argument('--num_io_process', type=int, default=40, help='number of io processes')
 
@@ -82,11 +111,13 @@ def get_labels(csv_file, label_col):
         valid_data = (*valid_data, valid_labels)
 
     else:
-        test_data = GLXYZAW_from_file(args.test_path, args.atom_types,
-                                      args.wyck_types, args.n_max, args.num_io_process)
-        test_labels = get_labels(args.test_path, args.property)
-
-        test_data = (*test_data, test_labels)
+        if args.spacegroup == None:
+            G, L, XYZ, A, W = GLXYZAW_from_file(args.test_path, args.atom_types,
+                                                args.wyck_types, args.n_max, args.num_io_process)
+            test_labels = get_labels(args.test_path, args.property)
+
+        else:
+            G, L, XYZ, A, W = GLXYZAW_from_sample(args.spacegroup, args.test_path)
 
     ################### Model #############################
     transformer_params, state, transformer = make_transformer(key, args.Nf, args.Kx, args.Kl, args.n_max,
@@ -146,12 +177,16 @@ def get_labels(csv_file, label_col):
         params, opt_state = train(subkey, optimizer, opt_state, loss_fn, params, state, epoch_finished, args.epochs, args.batchsize, train_data, valid_data, output_path)
 
     elif args.optimizer == 'none':
-        G, L, XYZ, A, W, labels = test_data
+
        y = jax.vmap(forward_fn,
                     in_axes=(None, None, None, 0, 0, 0, 0, 0, None)
                     )(params, state, key, G, L, XYZ, A, W, False)
 
        jnp.save(args.output_path, y)
 
     else:
-        raise NotImplementedError(f"Optimizer {args.optimizer} not implemented")
+        raise NotImplementedError(f"Optimizer {args.optimizer} not implemented")
+
+
+if __name__ == "__main__":
+    main()
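A hypothetical usage sketch for the new `GLXYZAW_from_sample` helper (the import path assumes the renamed module, and the CSV name is illustrative): it reads a CSV of generated samples whose `L`, `X`, `A`, `W` columns hold Python-literal lists, attaches the given space group, rescales cell lengths by `num_atoms**(1/3)`, converts angles to radians, and returns arrays ready for the vmapped `forward_fn` call in the `optimizer == 'none'` branch above.

```python
# Hypothetical usage, not part of the commit; "samples_225.csv" is an assumed
# CSV of structures generated for space group 225.
from crystalformer.cli.classifier import GLXYZAW_from_sample

G, L, XYZ, A, W = GLXYZAW_from_sample(spg=225, test_path="samples_225.csv")
# G: (num_samples,) space-group numbers
# L: (num_samples, 6) scaled cell lengths plus angles in radians
# XYZ, A, W: per-site fractional coordinates, atom types, and Wyckoff indices
print(G.shape, L.shape, XYZ.shape, A.shape, W.shape)
```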

cond_gen.py renamed to crystalformer/cli/cond_gen.py

Lines changed: 5 additions & 1 deletion
@@ -17,7 +17,7 @@
 from crystalformer.src.transformer import make_transformer
 
 
-if __name__ == "__main__":
+def main():
 
     import argparse
     parser = argparse.ArgumentParser(description='')
@@ -248,3 +248,7 @@
     data.to_csv(filename, mode='a', index=False, header=header)
 
     print ("Wrote samples to %s"%filename)
+
+
+if __name__ == "__main__":
+    main()
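Moving the scripts into `crystalformer/cli/` and wrapping their bodies in `main()` functions is the usual pattern for exposing them as console commands (such as the `train_ppo` command used in the README section above). A hypothetical setuptools registration is sketched below; the actual packaging and module names of this commit are not shown in the diff.

```python
# Hypothetical setup.py sketch, not from this commit; entry-point names and
# modules are assumptions based on the new crystalformer/cli/ layout.
from setuptools import setup, find_packages

setup(
    name="crystalformer",
    packages=find_packages(),
    entry_points={
        "console_scripts": [
            "cond_gen = crystalformer.cli.cond_gen:main",
            "classifier = crystalformer.cli.classifier:main",
        ],
    },
)
```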

crystalformer/cli/dataset.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+import os
+import lmdb
+import pickle
+import numpy as np
+from crystalformer.src.utils import GLXYZAW_from_file
+import warnings
+warnings.filterwarnings("ignore")
+
+
+def csv_to_lmdb(csv_file, lmdb_file, args):
+    if os.path.exists(lmdb_file):
+        os.remove(lmdb_file)
+        print(f"Removed existing {lmdb_file}")
+
+    values = GLXYZAW_from_file(csv_file,
+                               atom_types=args.atom_types,
+                               wyck_types=args.wyck_types,
+                               n_max=args.n_max,
+                               num_workers=args.num_workers)
+    keys = np.arange(len(values[0]))
+
+    env = lmdb.open(
+        lmdb_file,
+        subdir=False,
+        readonly=False,
+        lock=False,
+        readahead=False,
+        meminit=False,
+        max_readers=1,
+        map_size=int(100e9),
+    )
+
+    with env.begin(write=True) as txn:
+        for key, value in zip(keys, values):
+            txn.put(str(key).encode("utf-8"), pickle.dumps(value))
+
+    print(f"Successfully converted {csv_file} to {lmdb_file}")
+
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--n_max', type=int, default=21, help='The maximum number of atoms in the cell')
+    parser.add_argument('--atom_types', type=int, default=119, help='Atom types including the padded atoms')
+    parser.add_argument('--wyck_types', type=int, default=28, help='Number of possible multiplicites including 0')
+
+    parser.add_argument("--path", type=str, required=True)
+    parser.add_argument("--num_workers", type=int, default=40)
+    args = parser.parse_args()
+
+    for i in ["test", "val", "train"]:
+        csv_to_lmdb(
+            os.path.join(args.path, f"{i}.csv"),
+            os.path.join(args.path, f"{i}.lmdb"),
+            args
+        )
+
+
+if __name__ == "__main__":
+    main()
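A hedged sketch of reading back an LMDB file written by `csv_to_lmdb` (the file name is illustrative): records are keyed by stringified integers and hold pickled values, so a consumer only needs `lmdb` and `pickle`.

```python
# Hypothetical read-back, not part of the commit; "train.lmdb" is an assumed path.
import lmdb
import pickle

env = lmdb.open("train.lmdb", subdir=False, readonly=True, lock=False)
with env.begin(write=False) as txn:
    raw = txn.get(str(0).encode("utf-8"))  # record stored under key "0"
    record = pickle.loads(raw)             # whatever csv_to_lmdb pickled for this key
print(type(record))
```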
