Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.venv
*.pyc
dist
datasets
Expand All @@ -17,4 +18,4 @@ weights
*.flo
*.egg-info
*.npy
*.npz
*.npz
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,28 @@ If you find SEA-RAFT useful for your work, please consider citing our academic p
year={2024}
}
```
## installation

```bash
pip install git+https://github.com/princeton-vl/SEA-RAFT
```


```bash
pip install sea-raft # if its has gotten uploaded to pypi
```


## Requirements
Our code is developed with pytorch 2.2.0, CUDA 12.2 and python 3.10.
```Shell
conda create --name SEA-RAFT python=3.10.13
conda activate SEA-RAFT
pip install -r requirements.txt
pip install . # if you have cloned it
pip install git+https://github.com/princeton-vl/SEA-RAFT # if you want to use it
```


## Model Zoo

Google Drive: [link](https://drive.google.com/drive/folders/1YLovlvUW94vciWvTyLf-p3uWscbOQRWW?usp=sharing).
Expand Down
File renamed without changes.
8 changes: 4 additions & 4 deletions custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@

from config.parser import parse_args

import datasets
from raft import RAFT
from utils.flow_viz import flow_to_image
from utils.utils import load_ckpt
import sea_raft.datasets
from sea_raft.raft import RAFT
from sea_raft.utils.flow_viz import flow_to_image
from sea_raft.utils.utils import load_ckpt

def create_color_bar(height, width, color_map):
"""
Expand Down
6 changes: 3 additions & 3 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from config.parser import parse_args

import datasets
from raft import RAFT
from utils.flow_viz import flow_to_image
from utils.utils import load_ckpt
from sea_raft.raft import RAFT
from sea_raft.utils.flow_viz import flow_to_image
from sea_raft.utils.utils import load_ckpt

def create_color_bar(height, width, color_map):
"""
Expand Down
13 changes: 7 additions & 6 deletions eval_ptlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

import datasets
from raft import RAFT
from tqdm import tqdm

from utils import flow_viz
from utils import frame_utils
from utils.utils import resize_data, load_ckpt
import sea_raft.datasets
from sea_raft.raft import RAFT


from sea_raft.utils import flow_viz
from sea_raft.utils import frame_utils
from sea_raft.utils.utils import resize_data, load_ckpt

import ptlflow
from ptlflow.utils import flow_utils
Expand Down
7 changes: 4 additions & 3 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@

from config.parser import parse_args

import datasets
from raft import RAFT
import sea_raft.datasets
from sea_raft.raft import RAFT
from tqdm import tqdm
from utils.utils import resize_data, load_ckpt
from sea_raft.utils.utils import resize_data, load_ckpt

def forward_flow(args, model, image1, image2):
output = model(image1, image2, iters=args.iters, test_mode=True)
Expand Down Expand Up @@ -169,6 +169,7 @@ def eval(args):

def main():
parser = argparse.ArgumentParser()

parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str)
parser.add_argument('--model', help='checkpoint path', required=True, type=str)
args = parse_args(parser)
Expand Down
13 changes: 6 additions & 7 deletions profile_ptlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

import datasets
from raft import RAFT
from tqdm import tqdm

from utils import flow_viz
from utils import frame_utils
from utils.profile import profile_model
from utils.utils import resize_data, load_ckpt
import sea_raft.datasets
from sea_raft.raft import RAFT
from sea_raft.utils import flow_viz
from sea_raft.utils import frame_utils
from sea_raft.utils.profile import profile_model
from sea_raft.utils.utils import resize_data, load_ckpt

import ptlflow
from ptlflow.utils import flow_utils
Expand Down
2 changes: 1 addition & 1 deletion profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
import torch
from config.parser import parse_args
from raft import RAFT
from sea_raft.raft import RAFT

def main():
parser = argparse.ArgumentParser()
Expand Down
40 changes: 40 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
[project]
name = "sea-raft"
version = "0.0.1"
description = """
We introduce SEA-RAFT, a more simple, efficient, and accurate RAFT for optical flow. Compared with RAFT, SEA-RAFT is trained with a new loss (mixture of Laplace). It directly regresses an initial flow for faster convergence in iterative refinements and introduces rigid-motion pre-training to improve generalization. SEA-RAFT achieves state-of-the-art accuracy on the Spring benchmark with a 3.69 endpoint-error (EPE) and a 0.36 1-pixel outlier rate (1px), representing 22.9% and 17.8% error reduction from best-published results. In addition, SEA-RAFT obtains the best cross-dataset generalization on KITTI and Spring. With its high efficiency, SEA-RAFT operates at least 2.3x faster than existing methods while maintaining competitive performance.
"""
authors = [
{name = "Wang, Yihan"},
{name = "Lipson, Lahav"},
{name = "Deng, Jia"}
]
readme = "README.md"
requires-python = ">= 3.10.13"
dependencies= [
"torch",
"torchvision",
"torchaudio",
"numpy",
"matplotlib",
"scipy",
"opencv-python",
"tensorboard",
"h5py",
"tqdm",
"einops",
"huggingface-hub"
]

[project.optional-dependencies]
profiler =[
"ptlflow"
]
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"



[tool.setuptools.packages.find]
include = ["sea_raft*"]
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ opencv-python
tensorboard
h5py
tqdm
einops
einops
huggingface-hub
1 change: 1 addition & 0 deletions sea_raft/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import sea_raft.utils
2 changes: 1 addition & 1 deletion core/corr.py → sea_raft/corr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.utils import coords_grid, bilinear_sampler
from sea_raft.utils.utils import coords_grid, bilinear_sampler

try:
import alt_cuda_corr
Expand Down
6 changes: 3 additions & 3 deletions core/datasets.py → sea_raft/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from tqdm import tqdm
from glob import glob
import os.path as osp
from utils import frame_utils
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
from utils.utils import induced_flow, check_cycle_consistency
from sea_raft.utils import frame_utils
from sea_raft.utils.augmentor import FlowAugmentor, SparseFlowAugmentor
from sea_raft.utils.utils import induced_flow, check_cycle_consistency
from ddp_utils import *

class FlowDataset(data.Dataset):
Expand Down
2 changes: 1 addition & 1 deletion core/extractor.py → sea_raft/extractor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layer import BasicBlock, conv1x1, conv3x3
from sea_raft.layer import BasicBlock, conv1x1, conv3x3

class ResNetFPN(nn.Module):
"""
Expand Down
File renamed without changes.
File renamed without changes.
10 changes: 5 additions & 5 deletions core/raft.py → sea_raft/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import torch.nn as nn
import torch.nn.functional as F

from update import BasicUpdateBlock
from corr import CorrBlock
from utils.utils import coords_grid, InputPadder
from extractor import ResNetFPN
from layer import conv1x1, conv3x3
from sea_raft.update import BasicUpdateBlock
from sea_raft.corr import CorrBlock
from sea_raft.utils.utils import coords_grid, InputPadder
from sea_raft.extractor import ResNetFPN
from sea_raft.layer import conv1x1, conv3x3

from huggingface_hub import PyTorchModelHubMixin

Expand Down
2 changes: 1 addition & 1 deletion core/update.py → sea_raft/update.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from layer import ConvNextBlock
from sea_raft.layer import ConvNextBlock

class FlowHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256, output_dim=4):
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
10 changes: 5 additions & 5 deletions submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

from config.parser import parse_args

import datasets
from raft import RAFT
import sea_raft.datasets
from sea_raft.raft import RAFT
from tqdm import tqdm

from utils.flow_viz import flow_to_image
from utils import frame_utils
from utils.utils import load_ckpt, InputPadder
from sea_raft.utils.flow_viz import flow_to_image
from sea_raft.utils import frame_utils
from sea_raft.utils.utils import load_ckpt, InputPadder

def forward_flow(args, model, image1, image2):
output = model(image1, image2, iters=args.iters, test_mode=True)
Expand Down
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import torch
import torch.optim as optim

from raft import RAFT
from datasets import fetch_dataloader
from utils.utils import load_ckpt
from loss import sequence_loss
from sea_raft.raft import RAFT
from sea_raft.datasets import fetch_dataloader
from sea_raft.utils.utils import load_ckpt
from sea_raft.loss import sequence_loss
from ddp_utils import *

os.system("export KMP_INIT_AT_FORK=FALSE")
Expand Down