Skip to content

Commit db07625

Browse files
committed
update
1 parent 2ad6156 commit db07625

File tree

5 files changed

+48
-17
lines changed

5 files changed

+48
-17
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
python -m pip install --upgrade pip
2525
pip install pytest
2626
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
27-
pip install --upgrade "jax[cpu]==0.4.25"
27+
pip install -U jax
2828
pip install .
2929
- name: Test with pytest
3030
run: |

README.md

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ crystal space, which is crucial for data and compute efficient generative modeli
2323
- [CPU installation](#cpu-installation)
2424
- [CUDA (GPU) installation](#cuda-gpu-installation)
2525
- [install required packages](#install-required-packages)
26+
- [command line tools](#command-line-tools)
2627
- [Available Weights](#available-weights)
2728
- [How to run](#how-to-run)
2829
- [train](#train)
@@ -85,17 +86,14 @@ pip install -U "jax[cpu]"
8586

8687
### CUDA (GPU) installation
8788

88-
> [!CAUTION]
89-
> CrystalFormer requires JAX versions between v0.4.25 and v0.4.35. Please avoid using versions newer than v0.4.35 to prevent compatibility issues. If you have already installed a newer version, please uninstall it first and then install one within the required range.
90-
9189
If you intend to use CUDA (GPU) to speed up the training, it is important to install the appropriate version of `jax` and `jaxlib`. It is recommended to check the [jax docs](https://github.com/google/jax?tab=readme-ov-file#installation) for the installation guide. The basic installation command is given below:
9290

9391
```bash
9492
pip install --upgrade pip
9593

9694
# NVIDIA CUDA 12 installation
9795
# Note: wheels only available on linux.
98-
pip install --upgrade "jax[cuda12]"
96+
pip install -U "jax[cuda12]"
9997
```
10098

10199
### install required packages
@@ -104,6 +102,13 @@ pip install --upgrade "jax[cuda12]"
104102
pip install -r requirements.txt
105103
```
106104

105+
### command line tools
106+
To use the command line tools, you need to install the `crystalformer` package. You can use the following command to install the package:
107+
108+
```bash
109+
pip install .
110+
```
111+
107112
## Available Weights
108113

109114
We release the weights of the model trained on the MP-20 dataset and Alex-20 dataset. More details can be seen in the [model](./model/README.md) folder.
@@ -154,6 +159,9 @@ python ./scripts/awl2struct.py --output_path YOUR_PATH --label SPACE_GROUP --nu
154159
- `label`: the label to save the `cif` files, which is the space group number `g`
155160
- `num_io_process`: the number of processes
156161

162+
> [!IMPORTANT]
163+
> The following evaluation script requires the [`SMACT`](https://github.com/WMD-group/SMACT), [`matminer`](https://github.com/hackingmaterials/matminer), and [`matbench-genmetrics`](https://github.com/sparks-baird/matbench-genmetrics) packages. We recommend installing them in a separate environment to avoid conflicts with other packages.
164+
157165
Calculate the structure and composition validity of the generated structures:
158166

159167
```bash

requirements.txt

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
dm-haiku==0.0.11
2-
optax==0.1.8
3-
pymatgen==2024.3.1
4-
pyxtal==0.6.3
5-
SMACT==2.5.5
6-
matbench-genmetrics==0.6.1
7-
matminer @ git+https://github.com/hackingmaterials/matminer.git
1+
dm-haiku==0.1.14
2+
optax==0.2.4
3+
pyxtal==1.0.7

scripts/README.md

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
- [Relaxation](#relaxation)
1010
- [Energy Above the Hull](#energy-above-the-hull)
1111
- [Embedding Visualization](#embedding-visualization)
12+
- [Stable, Unique and Novel Structures](#stable-unique-and-novel-structures)
1213
- [Structure Visualization](#structure-visualization)
1314

1415
### Transform
@@ -48,15 +49,18 @@ Note that the training, test, and generated datasets should contain the structur
4849

4950

5051
### Relaxation
51-
`matgl_relax.py` is a script to relax the generated structures using the `matgl` package. You can install the `matgl` following the instructions in the [matgl repository](https://github.com/materialsvirtuallab/matgl?tab=readme-ov-file).
52+
`mlff_relax.py` is a script to relax the generated structures using pretrained machine learning force field. Now we support the [`orb`](https://github.com/orbital-materials/orb-models), [`MACE`](https://github.com/ACEsuit/mace), [`matgl`](https://github.com/materialsvirtuallab/matgl) and [`deepmd-kit`](https://github.com/deepmodeling/deepmd-kit) models. Please install corresponding packages before running the script.
53+
5254
```bash
53-
python matgl_relax.py --restore_path RESTORE_PATH --filename FILENAME --relaxation --model_path MODEL_PATH
55+
python mlff_relax.py --restore_path RESTORE_PATH --filename FILENAME --relaxation --model orb --model_path MODEL_PATH
5456
```
5557
- `restore_path`: the path to the generated structures
5658
- `filename`: the filename of the generated structures
5759
- `relaxation`: whether to relax the structures, if not specified, the script will only predict the energy of the structures without relaxation
58-
- `model_path`: the path to the `matgl` model checkpoint
59-
60+
- `model`: the model to use for relaxation, which can be `orb`, `mace`, `matgl` or `dp`
61+
- `model_path`: the path to the machine learning force field checkpoint
62+
- `primitive`: whether to convert the structures to primitive cells, if not specified, the script will only relax the structures without converting to primitive cells. This can be used to reduce the number of atoms in the structures and speed up the relaxation process
63+
- `fixsymmetry`: whether to fix the space group symmetry of the structures in the relaxation process
6064

6165
### Energy Above the Hull
6266
`e_above_hull.py` is a script to calculate the energy above the hull of the generated structures based on the Materials Project database. To calculate the energy above the hull, the API key of the Materials Project is required, which can be obtained from the [Materials Project website](https://next-gen.materialsproject.org/). Furthermore, the `mp_api` package should be installed.
@@ -70,6 +74,18 @@ python e_above_hull.py --restore_path RESTORE_PATH --filename FILENAME --api_key
7074
- `label`: the label to save the energy above the hull file
7175
- `relaxation`: whether to calculate the energy above the hull based on the relaxed structures
7276

77+
`e_above_hull_alex.py` is a script to calculate the energy above the hull of the generated structures based on the Alexandria database. To calculate the energy above the hull, the Alexandria convex hull data is required, which can be obtained from the [Alexandria website](https://alexandria.icams.rub.de/).
78+
79+
```bash
80+
python e_above_hull_alex.py --convex_path CONVEX_PATH --restore_path RESTORE_PATH --filename FILENAME --api_key API_KEY --label LABEL --relaxation
81+
```
82+
- `convex_path`: the path to the Alexandria convex hull data
83+
- `restore_path`: the path to the structures
84+
- `filename`: the filename of the structures
85+
- `api_key`: the API key of the Materials Project
86+
- `label`: the label to save the energy above the hull file
87+
- `relaxation`: whether to calculate the energy above the hull based on the relaxed structures
88+
7389
### Embedding Visualization
7490
`plot_embeddings.py` is a script to visualize the correlation of the learned embedding vectors of different elements.
7591

@@ -79,5 +95,16 @@ python plot_embeddings.py --restore_path RESTORE_PATH
7995

8096
- `restore_path`: the path to the model checkpoint
8197

98+
### Stable, Unique and Novel Structures
99+
`check_sun_materials.py` is a script to check the stable, unique and novel structures based on the given reference dataset.
100+
101+
```bash
102+
python check_sun_materials.py --restore_path RESTORE_PATH --filename FILENAME --ref_path REF_PATH
103+
```
104+
105+
- `restore_path`: the path to the generated structures
106+
- `filename`: the filename of the generated structures
107+
- `ref_path`: the path to the reference dataset
108+
82109
### Structure Visualization
83110
`structure_visualization.ipynb` is a notebook to visualize the generated structures.

scripts/mlff_relax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def main(args):
174174

175175
import argparse
176176
parser = argparse.ArgumentParser()
177-
parser.add_argument("--model", type=str, choices=["orb", "matgl", "mace"], default="orb", help="choose the MLFF model")
177+
parser.add_argument("--model", type=str, choices=["orb", "matgl", "mace", "dp"], default="orb", help="choose the MLFF model")
178178
parser.add_argument("--device", type=str, default="cuda", help="choose the device to run the model on")
179179
parser.add_argument("--model_path", type=str, default="./data/orb-v2-20241011.ckpt", help="path to the model checkpoint")
180180
parser.add_argument("--restore_path", type=str, default="./experimental/", help="")

0 commit comments

Comments
 (0)