
[ NeurIPS 2024 ] The official PyTorch implementation for Learning Truncated Causal History Model for Video Restoration.

Ascend-Research/Turtle

Turtle: Learning Truncated Causal History Model for Video Restoration [NeurIPS'2024]

📄 arXiv | 🌐 Website

The official PyTorch implementation for Learning Truncated Causal History Model for Video Restoration, accepted to NeurIPS 2024.

  • Turtle achieves state-of-the-art results on multiple video restoration benchmarks, offering superior computational efficiency and enhanced restoration quality 🔥🔥🔥.
  • 🛠️💡 Model Forge: Easily design your own architecture by modifying the option file.
    • You can choose from several layer types (channel attention, simple channel attention, CHM, FHR, or custom blocks) as well as different types of feed-forward layers.
    • This lets you build custom networks and experiment with layer and feed-forward configurations to suit your needs.
  • If you like this project, please give us a ⭐ on GitHub! 🚀
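Config-driven layer selection of the kind Model Forge describes is commonly built on a name-to-builder registry. The framework-free sketch below illustrates that pattern under stated assumptions: LAYER_REGISTRY, register, Block, build_network, and the config keys type, channels, and ffn are all hypothetical names chosen for illustration, not Turtle's actual option-file schema or code.

```python
from dataclasses import dataclass

# Hypothetical registry mapping a layer-type name (as it might appear in an
# option file) to a builder. Not the repository's actual schema or code.
LAYER_REGISTRY = {}


def register(name):
    """Decorator that stores a builder function under `name`."""
    def wrap(builder):
        LAYER_REGISTRY[name] = builder
        return builder
    return wrap


@dataclass
class Block:
    """Stand-in for a network block; in PyTorch this would be an nn.Module."""
    kind: str
    channels: int
    ffn: str


@register("channel_attention")
def build_channel_attention(channels, ffn):
    return Block("channel_attention", channels, ffn)


@register("simple_channel_attention")
def build_simple_channel_attention(channels, ffn):
    return Block("simple_channel_attention", channels, ffn)


@register("CHM")
def build_chm(channels, ffn):
    return Block("CHM", channels, ffn)


def build_network(config, default_ffn="gated"):
    """Build one block per entry in config["layers"]; unknown types raise KeyError."""
    blocks = []
    for entry in config["layers"]:
        kind = entry["type"]
        if kind not in LAYER_REGISTRY:
            raise KeyError(f"Unknown layer type: {kind!r}")
        # "gated" is a hypothetical default feed-forward name for this sketch.
        blocks.append(LAYER_REGISTRY[kind](entry["channels"], entry.get("ffn", default_ffn)))
    return blocks
```

With this pattern, adding a custom block amounts to registering one more builder; the config then refers to it by name without any change to build_network.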

🔥 📰 News 🔥

  • Oct. 10, 2024: The paper is now available on arXiv, along with the code and pretrained models.
  • Sept. 25, 2024: Turtle is accepted to NeurIPS'2024.

Table of Contents

  • Installation
  • Trained Models
  • Dataset Preparation
  • Training
  • Evaluation
  • Model Complexity and Inference Speed
  • Acknowledgments
  • Citation

Installation

This implementation is based on BasicSR, an open-source toolbox for image and video restoration.

Requirements: Python 3.9.5, PyTorch 1.11.0, and CUDA 11.3. Install the dependencies and set up the package with:

pip install -r requirements.txt
python setup.py develop --no_cuda_ext
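To compare a local setup with the versions listed above, a small helper like the one below can be used. It is a sketch, not part of the repository; check_environment is a hypothetical name, and it only imports PyTorch if PyTorch is installed.

```python
import importlib.util
import sys

# Versions listed in the Installation section.
EXPECTED_PYTHON = (3, 9)
EXPECTED_TORCH = "1.11"


def check_environment():
    """Report local Python/PyTorch/CUDA versions alongside the expected ones."""
    report = {
        "python": "%d.%d" % sys.version_info[:2],
        "python_ok": sys.version_info[:2] == EXPECTED_PYTHON,
    }
    if importlib.util.find_spec("torch") is None:
        report["torch"] = None  # PyTorch is not installed
        return report
    import torch

    report["torch"] = torch.__version__
    report["torch_ok"] = torch.__version__.startswith(EXPECTED_TORCH)
    report["cuda"] = torch.version.cuda  # None on CPU-only builds
    report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```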

Trained Models

You can download our trained models from Google Drive: Trained Models

1. Dataset Preparation

To obtain the datasets, follow the official instructions from each dataset's provider and download them into the datasets folder. The datasets for each task are available from the following links (official sources reported by their respective authors):

  1. Desnowing: RSVD
  2. Raindrops and Rainstreaks Removal: VRDS
  3. Night Deraining: NightRain
  4. Synthetic Deblurring: GoPro
  5. Real-World Deblurring: BSD3ms-24ms
  6. Denoising: DAVIS | Set8
  7. Real-World Super Resolution: MVSR

Organize each dataset as follows, with ground-truth frames in 'gt' and degraded frames in 'blur':

./datasets/
└── Dataset_name/
    ├── train/
    └── test/
        ├── blur/
        │   ├── video_1/
        │   │   ├── Frame1
        │   │   └── ...
        │   └── video_n/
        │       ├── Frame1
        │       └── ...
        └── gt/
            ├── video_1/
            │   ├── Frame1
            │   └── ...
            └── video_n/
                ├── Frame1
                └── ...
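Before training or testing, it can help to verify that every degraded video has a matching ground-truth video. The sketch below assumes that paired frames share the same file names in blur/ and gt/ (an assumption; check the data loader if your naming differs). check_split is a hypothetical helper, not part of the repository.

```python
from pathlib import Path


def check_split(split_dir):
    """Check that every video in blur/ has a gt/ twin with the same frame names.

    Returns a list of problem descriptions (empty if the layout looks consistent).
    """
    split = Path(split_dir)
    blur, gt = split / "blur", split / "gt"
    for sub in (blur, gt):
        if not sub.is_dir():
            return [f"missing folder: {sub}"]

    problems = []
    blur_videos = {p.name for p in blur.iterdir() if p.is_dir()}
    gt_videos = {p.name for p in gt.iterdir() if p.is_dir()}

    # Videos present in only one of the two folders.
    for name in sorted(blur_videos ^ gt_videos):
        problems.append(f"video only in one of blur/gt: {name}")

    # Videos present in both: frame file names should match one-to-one.
    for name in sorted(blur_videos & gt_videos):
        blur_frames = sorted(f.name for f in (blur / name).iterdir() if f.is_file())
        gt_frames = sorted(f.name for f in (gt / name).iterdir() if f.is_file())
        if blur_frames != gt_frames:
            problems.append(
                f"{name}: {len(blur_frames)} blur vs {len(gt_frames)} gt frames, or names differ"
            )
    return problems
```

For example: print(check_split("./datasets/Dataset_name/test") or "layout looks consistent").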

2. Training

To train the model, first select the appropriate data loader in basicsr/train.py. There are two options:

  1. For deblurring, denoising, deraining, and similar tasks, keep the following import line and comment out the super-resolution one:
     from basicsr.data.video_image_dataset import VideoImageDataset

  2. For super-resolution, keep the following import line and comment out the previous one:
     from basicsr.data.video_super_image_dataset import VideoSuperImageDataset as VideoImageDataset

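Then launch distributed training; --nproc_per_node sets the number of processes per node (one per GPU, 8 in this example), and -opt points to the option file for your task:

python -m torch.distributed.launch --nproc_per_node=8 --master_port=8080 basicsr/train.py -opt options/option_file_name.yml --launcher pytorch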
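Instead of commenting import lines in and out, the loader could be chosen programmatically. This is a hypothetical sketch: dataset_spec, load_dataset_class, SR_TASKS, and where the task name would come from (for example, a new field in the option file) are assumptions, while the module and class names are taken from the two import lines above.

```python
import importlib

# The two data-loader imports from the Training section, keyed by task family.
DATASET_CLASSES = {
    "restoration": ("basicsr.data.video_image_dataset", "VideoImageDataset"),
    "sr": ("basicsr.data.video_super_image_dataset", "VideoSuperImageDataset"),
}

# Hypothetical task names that should use the super-resolution loader.
SR_TASKS = {"sr", "superresolution", "super-resolution"}


def dataset_spec(task):
    """Return (module, class) for a task name; SR gets its own loader."""
    return DATASET_CLASSES["sr" if task.lower() in SR_TASKS else "restoration"]


def load_dataset_class(task):
    """Import lazily so only the module for the chosen task is loaded."""
    module_name, class_name = dataset_spec(task)
    return getattr(importlib.import_module(module_name), class_name)


# Hypothetical use inside basicsr/train.py, replacing the manual (un)commenting:
# VideoImageDataset = load_dataset_class(task)
```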

3. Evaluation

The pretrained models can be downloaded from the Google Drive link in the Trained Models section.

3.1 Testing the model

To evaluate a pretrained model, run:

python inference.py

Before running, adjust the function parameters in inference.py to match your task:

  1. config: Specify the path to the option file.
  2. model_path: Provide the path to the pretrained model.
  3. dataset_name: Select the dataset you are using ("RSVD", "GoPro", "SR", "NightRain", "DVD", "Set8").
  4. task_name: Choose the restoration task ("Desnowing", "Deblurring", "SR", "Deraining", "Denoising").
  5. model_type: Indicate the model type ("t0", "t1", "SR").
  6. save_image: Set to True if you want to save the output images; provide the output path in image_out_path.
  7. do_patches: Enable to process frames in overlapping patches (tiles); adjust tile and tile_overlap as needed (defaults: 320 and 128).
  8. y_channel_PSNR: Enable to compute PSNR/SSIM on the Y channel (default: False).
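For reference, the do_patches and y_channel_PSNR options are commonly implemented along these lines: overlapping tiles advance by tile minus tile_overlap, with a final tile aligned to the frame edge, and PSNR is computed on the ITU-R BT.601 luma channel. This is an illustrative sketch assuming RGB frames with values in [0, 1]; tile_starts and psnr_y are hypothetical helpers, and inference.py may differ in details such as channel order (BGR), border cropping, or how overlapping tile outputs are blended (averaging is a common choice).

```python
import math

import numpy as np


def tile_starts(size, tile=320, overlap=128):
    """Start offsets of overlapping tiles along one image axis.

    Tiles advance by `tile - overlap`; a final tile is aligned to the far edge
    so the whole axis is covered. Axes no larger than `tile` get one tile.
    """
    if size <= tile:
        return [0]
    stride = tile - overlap
    starts = list(range(0, size - tile, stride))
    return starts + [size - tile]


def psnr_y(img1, img2):
    """PSNR on the luma (Y) channel of two RGB images (H, W, 3) in [0, 1].

    Uses the BT.601 conversion Y = 16 + 65.481 R + 128.553 G + 24.966 B,
    which puts Y on a 0-255 scale.
    """
    weights = np.array([65.481, 128.553, 24.966])
    y1 = img1 @ weights + 16.0
    y2 = img2 @ weights + 16.0
    mse = np.mean((y1 - y2) ** 2)
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(255.0 ** 2 / mse)
```

For a 720-pixel axis with the default settings, tile_starts(720) gives [0, 192, 384, 400]: a stride of 192, plus a last tile starting at 400 so it ends exactly at pixel 720.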

3.2 Running Turtle on Custom Videos

To restore your own video, first extract its frames, then run a pretrained model (for example, for desnowing) on them:

Step 1: Extract Frames from Video

  1. Edit video_to_frames.py:

    • Set the video_path to your input video file.
    • Set the output_folder to save extracted frames.
  2. Run the script:

    python video_to_frames.py
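If you prefer not to edit video_to_frames.py, a minimal OpenCV-based alternative is sketched below. It assumes opencv-python is installed; extract_frames, frame_filename, and the zero-padded naming scheme are hypothetical choices, picked so that sorting file names preserves frame order (otherwise 10.png would sort before 2.png).

```python
import os


def frame_filename(index, digits=5, ext="png"):
    """Zero-padded name so lexicographic order equals temporal order."""
    return f"{index:0{digits}d}.{ext}"


def extract_frames(video_path, output_folder):
    """Write every frame of `video_path` into `output_folder`; return the count."""
    import cv2  # requires opencv-python

    os.makedirs(output_folder, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {video_path}")
    count = 0
    while True:
        ok, frame = cap.read()
        if not ok:  # end of stream or read error
            break
        cv2.imwrite(os.path.join(output_folder, frame_filename(count)), frame)
        count += 1
    cap.release()
    return count
```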

Step 2: Run Model Inference

  1. Edit inference_no_ground_truth.py:

    • Set paths for config, model_path, data_dir (extracted frames), and image_out_path (output frames).
  2. Run the script:

    python inference_no_ground_truth.py
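To turn the restored frames in image_out_path back into a video, an OpenCV sketch like the following can be used. It is not part of the repository; list_frames and frames_to_video are hypothetical helpers that assume sortable (zero-padded) frame names, and fps should match the source video (cv2.VideoCapture(path).get(cv2.CAP_PROP_FPS) returns it).

```python
import os

IMAGE_EXTS = (".png", ".jpg", ".jpeg")


def list_frames(folder):
    """Sorted image paths in `folder`; assumes zero-padded, sortable names."""
    names = sorted(n for n in os.listdir(folder) if n.lower().endswith(IMAGE_EXTS))
    return [os.path.join(folder, n) for n in names]


def frames_to_video(frame_folder, video_path, fps=30.0):
    """Write the frames in `frame_folder` to an .mp4 file; return the frame count."""
    import cv2  # requires opencv-python

    paths = list_frames(frame_folder)
    if not paths:
        raise FileNotFoundError(f"No frames found in {frame_folder}")
    height, width = cv2.imread(paths[0]).shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    for path in paths:
        writer.write(cv2.imread(path))
    writer.release()
    return len(paths)
```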

4. Model Complexity and Inference Speed

  • To measure the parameter count, MACs, and inference speed, run:
python basicsr/models/archs/turtle_arch.py
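As a rough guide to what these numbers mean, the sketch below counts parameters and multiply-accumulates (MACs) for a single 2-D convolution and times a callable after a few warm-up runs. It is illustrative only and does not reproduce what turtle_arch.py does internally; conv2d_complexity and average_latency are hypothetical helpers.

```python
import time


def conv2d_complexity(c_in, c_out, k, h_out, w_out, bias=True):
    """Parameters and MACs of one 2-D convolution with a k x k kernel.

    Each output value needs c_in * k * k MACs, and there are
    c_out * h_out * w_out output values.
    """
    params = c_out * c_in * k * k + (c_out if bias else 0)
    macs = c_out * c_in * k * k * h_out * w_out
    return params, macs


def average_latency(fn, warmup=3, iters=10):
    """Mean wall-clock seconds per call of `fn`, measured after warm-up calls.

    For CUDA models, call torch.cuda.synchronize() inside `fn` after the
    forward pass, because GPU kernels run asynchronously.
    """
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters
```

For example, conv2d_complexity(3, 64, 3, 256, 256) gives 1,792 parameters and 113,246,208 MACs (about 113.2 M).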

5. Acknowledgments

This codebase borrows from the BasicSR and ShiftNet repositories.

6. Citation

If you find our work useful, please consider citing our paper in your research.

@inproceedings{ghasemabadilearning,
  title={Learning Truncated Causal History Model for Video Restoration},
  author={Ghasemabadi, Amirhosein and Janjua, Muhammad Kamran and Salameh, Mohammad and Niu, Di},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024}
}