A transformer-encoder-based neutrino energy estimator. This is a highly flexible framework that allows easy modification of the model, loss function, and data loader. Currently, the code supports CSV files in which each array is stored as a comma-separated string. We provide example scripts for the NOvA and DUNE experiments.
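To illustrate the input format, the sketch below parses a CSV in which one column holds an array encoded as a comma-separated string. It uses only the Python standard library rather than Pandas or Polars, and the column names (`E_true`, `prong_energy`) and helper names are hypothetical; the repository's own data loader may work differently. In standard CSV, such a field has to be quoted so that its embedded commas are not read as column separators.

```python
import csv
import io
from typing import Dict, List, Set

# Hypothetical CSV: "E_true" is a scalar column; "prong_energy" holds a
# variable-length array stored as a quoted, comma-separated string.
SAMPLE = '''E_true,prong_energy
1.25,"0.4,0.3,0.2"
2.10,"1.1"
'''


def parse_array(cell: str) -> List[float]:
    """Convert a string such as "0.4,0.3" into a list of floats."""
    cell = cell.strip()
    if not cell:
        return []  # an empty field becomes an empty array
    return [float(x) for x in cell.split(",")]


def read_events(text: str, array_columns: Set[str]) -> List[Dict[str, object]]:
    """Read every row, decoding array columns and converting the rest to float."""
    rows = []
    for record in csv.DictReader(io.StringIO(text)):
        row = {}
        for key, value in record.items():
            row[key] = parse_array(value) if key in array_columns else float(value)
        rows.append(row)
    return rows


events = read_events(SAMPLE, {"prong_energy"})
print(events[0]["prong_energy"])  # [0.4, 0.3, 0.2]
```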
- Transformer-encoder-based model.
- Customizable loss function (arbitrarily complex loss functions can be designed).
- Customizable model.
- Customizable optimizer.
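As a hedged illustration of the customizable loss, the sketch below composes a loss as a weighted sum of simpler terms. The term names (`mse`, `mean_abs_rel_error`) and the `combine` helper are hypothetical and are not the repository's loss API. The terms operate on plain Python lists so the example runs anywhere; in a training loop they would be written with PyTorch tensor operations, and `combine` works unchanged as long as each term returns a scalar.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical loss signature: (predictions, targets) -> scalar.
LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mean_abs_rel_error(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean of |pred - true| / |true|, a resolution-style term (targets must be nonzero)."""
    return sum(abs(p - t) / abs(t) for p, t in zip(pred, target)) / len(pred)


def combine(terms: List[Tuple[float, LossFn]]) -> LossFn:
    """Return a loss equal to the weighted sum of the given terms."""
    def loss(pred, target):
        return sum(weight * fn(pred, target) for weight, fn in terms)
    return loss


total_loss = combine([(1.0, mse), (2.0, mean_abs_rel_error)])
print(total_loss([1.0, 2.0], [1.0, 4.0]))  # 2.5
```

Keeping each term independent makes it easy to add, remove, or reweight terms without touching the training loop.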
Running the code inside a virtual environment (venv) or a container is recommended.
- Python 3.10 was used for development.
- PyTorch 2.3.0 was used for development; it supports both CUDA and Apple MPS. By default, the code uses CUDA if it is available and otherwise uses Apple MPS.
- Pandas (default dataframe library)
- Polars (an optional faster dataframe library)
- NumPy
- Matplotlib >= 3.5
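The default device preference described above (CUDA first, then Apple MPS) can be sketched as follows. The decision is kept in a pure function so it can be checked without a GPU; the CPU fallback is an assumption added for machines with neither backend, and both function names are hypothetical rather than taken from the repository. `torch.cuda.is_available()` and `torch.backends.mps.is_available()` are standard PyTorch calls.

```python
def pick_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Prefer CUDA, then Apple MPS; fall back to CPU (assumed fallback)."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"  # assumption: the README does not state a CPU fallback


def select_device():
    """Query PyTorch for available backends and return a torch.device."""
    import torch  # imported lazily so pick_device_name is testable without PyTorch

    name = pick_device_name(
        torch.cuda.is_available(),
        torch.backends.mps.is_available(),
    )
    return torch.device(name)
```

A model and its input tensors would then be moved to the selected device with `.to(select_device())`.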
- Make sure conda is installed on your computer. See https://docs.anaconda.com/free/miniconda/miniconda-install/ for more details.
- Create a new conda environment (Python 3.10 matches the development setup) by running
conda create --name new_env_name python=3.10
- Activate the new environment by running
conda activate new_env_name
- Install PyTorch by running
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
See https://pytorch.org/get-started/locally/ for the command that matches your platform and CUDA version.
- Install the other required packages by running
conda install scipy pandas=2.0 numpy matplotlib
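After installation, a quick sanity check can confirm that the packages are visible to Python and that Matplotlib meets the >= 3.5 requirement. This is a minimal sketch using only `importlib.metadata`; the package lists and helper names are assumptions, and the simple version comparison reads only leading numeric components (for example, `2.3.0+cu121` is read as `(2, 3, 0)`).

```python
import re
from importlib import metadata
from typing import Dict, Optional, Tuple

# Distribution names as seen by Python (the conda package "pytorch" provides "torch").
REQUIRED = ["torch", "pandas", "numpy", "matplotlib", "scipy"]
OPTIONAL = ["polars"]


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def at_least(version: str, minimum: Tuple[int, ...]) -> bool:
    """Compare leading numeric components, e.g. '3.8.2' -> (3, 8, 2)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts) >= minimum


def report() -> Dict[str, Optional[str]]:
    """Print one line per package and warn if Matplotlib is too old."""
    versions = {name: installed_version(name) for name in REQUIRED + OPTIONAL}
    for name, ver in versions.items():
        if ver is None:
            ver = "not installed (optional)" if name in OPTIONAL else "MISSING"
        print(f"{name:<12}{ver}")
    mpl = versions["matplotlib"]
    if mpl is not None and not at_least(mpl, (3, 5)):
        print("warning: matplotlib >= 3.5 is required")
    return versions


if __name__ == "__main__":
    report()
```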
PYTHONPATH should include the top-level directory of the repository. For example, if this repository is cloned to /home/user/transformer_EE, add the following line to your .bashrc file:
export PYTHONPATH=/home/user/transformer_EE:$PYTHONPATH
Check that PYTHONPATH is set correctly by opening a new shell (or running source ~/.bashrc) and then running the following command:
echo $PYTHONPATH
If PYTHONPATH was previously empty, the command should print:
/home/user/transformer_EE:
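The same check can be done from inside Python. This hypothetical helper tests whether a directory appears among the `PYTHONPATH` entries, ignoring empty entries such as the one produced by the trailing colon; `/home/user/transformer_EE` is the example path used above and should be replaced with the location of your clone.

```python
import os


def on_pythonpath(directory: str, pythonpath: str) -> bool:
    """Return True if `directory` is one of the non-empty entries of `pythonpath`."""
    target = os.path.normpath(os.path.expanduser(directory))
    entries = [entry for entry in pythonpath.split(os.pathsep) if entry]
    return any(os.path.normpath(os.path.expanduser(entry)) == target for entry in entries)


if __name__ == "__main__":
    repo_root = "/home/user/transformer_EE"  # example path; replace with your clone
    print(on_pythonpath(repo_root, os.environ.get("PYTHONPATH", "")))
```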
An example script is provided for the NC dataset. To run it, use the following command:
python3 train_script.py
The configuration is stored as a JSON file; the default config is located in transformer_ee/config. There are two ways to configure the code:
- Edit the config file directly.
- Modify the dictionary in train_script.py. For example, to select the model, add the following line to train_script.py:
input_d["model"]["name"] = "Transformer_EE_MV"
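Assuming train_script.py loads the JSON config into the `input_d` dictionary, the two routes differ only in where the change is made. The sketch below parses a JSON document and applies the model override from the example; the JSON content is a hypothetical stand-in for a file under transformer_ee/config (real files contain more keys), and `load_config` is an illustrative helper, not part of the repository.

```python
import copy
import json

# Hypothetical stand-in for a config file; key values are placeholders.
DEFAULT_CONFIG_TEXT = """
{
  "model": {"name": "placeholder_model"},
  "optimizer": {"name": "placeholder_optimizer", "lr": 0.001}
}
"""


def load_config(path: str) -> dict:
    """Read a JSON config file from disk (illustrative helper)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


base = json.loads(DEFAULT_CONFIG_TEXT)
input_d = copy.deepcopy(base)  # edit a copy so the defaults stay untouched
input_d["model"]["name"] = "Transformer_EE_MV"
print(json.dumps(input_d["model"]))  # {"name": "Transformer_EE_MV"}
```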
Transformer_EE supports WandB logging. To enable it, install WandB by running
pip install wandb
and then run the following command and enter your API key when prompted:
wandb init
A minimal example of using WandB logging is provided in the train_script.py file.
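For reference, a minimal logging wrapper around the WandB Python API might look like this. `wandb.init`, `Run.log`, and `Run.finish` are real WandB calls; the `MetricLogger` class, the project name, and the metric names are hypothetical and are not the logging code in train_script.py. Metrics are always kept in a local history, and `wandb` is imported only when logging is enabled, so the sketch also runs without WandB installed.

```python
from typing import Dict, List, Optional


class MetricLogger:
    """Record metrics locally and, if enabled, forward them to WandB."""

    def __init__(self, use_wandb: bool = False,
                 project: str = "transformer_ee_demo",  # hypothetical project name
                 config: Optional[dict] = None):
        self.history: List[Dict[str, float]] = []
        self.run = None
        if use_wandb:
            import wandb  # only required when WandB logging is enabled
            self.run = wandb.init(project=project, config=config)

    def log(self, metrics: Dict[str, float], step: int) -> None:
        self.history.append({"step": step, **metrics})
        if self.run is not None:
            self.run.log(metrics, step=step)

    def finish(self) -> None:
        if self.run is not None:
            self.run.finish()


# Usage with WandB disabled; pass use_wandb=True after running `wandb init`.
logger = MetricLogger(use_wandb=False)
for epoch, train_loss in enumerate([0.9, 0.7]):
    logger.log({"train_loss": train_loss}, step=epoch)
logger.finish()
```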