FactFormer

Official implementation of Scalable Transformer for PDE Surrogate Modeling (paper).

Getting started

The code is tested under PyTorch 1.8.2 and CUDA 11; later versions should also work. Other packages needed are NumPy, einops, Matplotlib, and tqdm.
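
In a fresh environment, these dependencies can be installed with pip (the package names below are the standard PyPI names; pick the torch build that matches your CUDA version):

pip install torch numpy einops matplotlib tqdm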

Use a layer of 3D factorized attention:

import torch
from libs.factorization_module import FABlock3D

fa_layer = FABlock3D(dim,                   # input dimension
                     dim_head,              # dimension of each attention head; expanded by kernel_multiplier when computing the kernel: d = dim_head * kernel_multiplier
                     latent_dim,            # output dimension of the projection operator
                     heads,                 # number of attention heads
                     dim_out,               # output dimension
                     kernel_multiplier,     # use more function bases to compute the kernel: k(x_i, x_j) = \sum_{c=1}^{d} q_c(x_i) k_c(x_j)
                     use_rope,              # whether to use rotary positional encoding; default True
                     scaling_factor         # scaling factor to modulate the kernel, e.g. 1/sqrt(d) as in scaled dot-product attention; default 1
                    )
# random input of shape (batch, x, y, z, channels)
z = torch.randn((4, 64, 64, 64, dim))
# axial coords
pos_x = torch.linspace(0, 1, 64).unsqueeze(-1)       # keep a channel dimension
pos_y = torch.linspace(0, 1, 64).unsqueeze(-1)
pos_z = torch.linspace(0, 1, 64).unsqueeze(-1)

z = fa_layer(z, [pos_x, pos_y, pos_z])
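
Several such layers can be chained into a small model. The sketch below is illustrative only: the constructor arguments are assumed to follow the order documented above, and the residual wiring is an assumption rather than the exact architecture from the paper (see the examples directory for the real models).

import torch
import torch.nn as nn
from libs.factorization_module import FABlock3D

class FactAttentionStack3D(nn.Module):
    # Illustrative stack of 3D factorized-attention layers (not the paper's model).
    def __init__(self, dim, depth=2):
        super().__init__()
        # arguments in the order documented above:
        # (dim, dim_head, latent_dim, heads, dim_out, kernel_multiplier, use_rope, scaling_factor)
        self.layers = nn.ModuleList([
            FABlock3D(dim, 32, 32, 4, dim, 2, True, 1)
            for _ in range(depth)
        ])

    def forward(self, z, pos_lst):
        for layer in self.layers:
            z = z + layer(z, pos_lst)    # residual connection (a design assumption)
        return z

model = FactAttentionStack3D(dim=32, depth=2)
u = torch.randn((1, 32, 32, 32, 32))                               # (batch, x, y, z, channels)
pos = [torch.linspace(0, 1, 32).unsqueeze(-1) for _ in range(3)]   # axial coordinates
u = model(u, pos)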

For running experiments on the problems discussed in the paper, please refer to the examples directory.

For example, to train a model for Darcy flow:

python darcy2d_fact.py --config darcy2d_fact.yml
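
The training script reads its hyperparameters from the YAML file passed via --config. As a purely hypothetical illustration of the pattern (the actual keys are defined by darcy2d_fact.yml itself, and PyYAML is assumed to be available), the config can be inspected with:

import yaml

# Hypothetical snippet: print the hyperparameters defined in the example config.
with open('darcy2d_fact.yml') as f:
    config = yaml.safe_load(f)   # plain Python dict
print(config)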

Dataset

We provide our generated datasets at:

The Darcy dataset can be obtained from FNO's repo.

Please refer to the scripts under the dataset directory to generate data or customize the dataset.

Pretrained checkpoints

We provide the trained models used in the paper, along with their training logs, at the links below. The code in the current repo has been reorganized and cleaned up for clarity; the checkpoints were trained with the scripts attached to them at these links.

Problem   Link
KM2d      link
Darcy2d   link
Smoke3d   link
Turb3d    link
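
A hedged sketch of restoring a downloaded checkpoint follows; the file name is hypothetical, and whether the file stores a bare state_dict or a wrapper dict depends on the training scripts attached at the links above.

import torch

# Hypothetical file name; the archives linked above define the real ones.
state = torch.load('factformer_darcy2d.pt', map_location='cpu')

# If the file is a bare state_dict:
#     model.load_state_dict(state)
# If it is a wrapper dict (optimizer state, epoch, ...), the weights usually
# sit under a key such as 'model' or 'state_dict'.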

Citation

If you find this repo helpful, please consider citing it with:

@misc{li2023scalable,
      title={Scalable Transformer for PDE Surrogate Modeling}, 
      author={Zijie Li and Dule Shu and Amir Barati Farimani},
      year={2023},
      eprint={2305.17560},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
