diff --git a/.gitignore b/.gitignore index cbc1eb7..ee7d6e5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.venv *.pyc dist datasets @@ -17,4 +18,4 @@ weights *.flo *.egg-info *.npy -*.npz \ No newline at end of file +*.npz diff --git a/README.md b/README.md index f81a8c3..34684c7 100644 --- a/README.md +++ b/README.md @@ -27,15 +27,28 @@ If you find SEA-RAFT useful for your work, please consider citing our academic p year={2024} } ``` +## installation + +```bash +pip install git+https://github.com/princeton-vl/SEA-RAFT +``` + + +```bash +pip install sea-raft # if its has gotten uploaded to pypi +``` + ## Requirements Our code is developed with pytorch 2.2.0, CUDA 12.2 and python 3.10. ```Shell conda create --name SEA-RAFT python=3.10.13 conda activate SEA-RAFT -pip install -r requirements.txt +pip install . # if you have cloned it +pip install git+https://github.com/princeton-vl/SEA-RAFT # if you want to use it ``` + ## Model Zoo Google Drive: [link](https://drive.google.com/drive/folders/1YLovlvUW94vciWvTyLf-p3uWscbOQRWW?usp=sharing). diff --git a/core/__init__.py b/__init__.py similarity index 100% rename from core/__init__.py rename to __init__.py diff --git a/custom.py b/custom.py index 208e4fb..ed508cb 100644 --- a/custom.py +++ b/custom.py @@ -13,10 +13,10 @@ from config.parser import parse_args -import datasets -from raft import RAFT -from utils.flow_viz import flow_to_image -from utils.utils import load_ckpt +import sea_raft.datasets +from sea_raft.raft import RAFT +from sea_raft.utils.flow_viz import flow_to_image +from sea_raft.utils.utils import load_ckpt def create_color_bar(height, width, color_map): """ diff --git a/demo.py b/demo.py index ce753cc..39c0440 100644 --- a/demo.py +++ b/demo.py @@ -14,9 +14,9 @@ from config.parser import parse_args import datasets -from raft import RAFT -from utils.flow_viz import flow_to_image -from utils.utils import load_ckpt +from sea_raft.raft import RAFT +from sea_raft.utils.flow_viz import flow_to_image +from sea_raft.utils.utils import load_ckpt def create_color_bar(height, width, color_map): """ diff --git a/eval_ptlflow.py b/eval_ptlflow.py index 03493db..79c1bd0 100644 --- a/eval_ptlflow.py +++ b/eval_ptlflow.py @@ -8,14 +8,15 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.data as data - -import datasets -from raft import RAFT from tqdm import tqdm -from utils import flow_viz -from utils import frame_utils -from utils.utils import resize_data, load_ckpt +import sea_raft.datasets +from sea_raft.raft import RAFT + + +from sea_raft.utils import flow_viz +from sea_raft.utils import frame_utils +from sea_raft.utils.utils import resize_data, load_ckpt import ptlflow from ptlflow.utils import flow_utils diff --git a/evaluate.py b/evaluate.py index b1f19db..2d5d672 100644 --- a/evaluate.py +++ b/evaluate.py @@ -10,10 +10,10 @@ from config.parser import parse_args -import datasets -from raft import RAFT +import sea_raft.datasets +from sea_raft.raft import RAFT from tqdm import tqdm -from utils.utils import resize_data, load_ckpt +from sea_raft.utils.utils import resize_data, load_ckpt def forward_flow(args, model, image1, image2): output = model(image1, image2, iters=args.iters, test_mode=True) @@ -169,6 +169,7 @@ def eval(args): def main(): parser = argparse.ArgumentParser() + parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str) parser.add_argument('--model', help='checkpoint path', required=True, type=str) args = parse_args(parser) diff --git a/profile_ptlflow.py b/profile_ptlflow.py index ab0cd76..e1a147b 100644 --- a/profile_ptlflow.py +++ b/profile_ptlflow.py @@ -8,15 +8,14 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.data as data - -import datasets -from raft import RAFT from tqdm import tqdm -from utils import flow_viz -from utils import frame_utils -from utils.profile import profile_model -from utils.utils import resize_data, load_ckpt +import sea_raft.datasets +from sea_raft.raft import RAFT +from sea_raft.utils import flow_viz +from sea_raft.utils import frame_utils +from sea_raft.utils.profile import profile_model +from sea_raft.utils.utils import resize_data, load_ckpt import ptlflow from ptlflow.utils import flow_utils diff --git a/profiler.py b/profiler.py index 838491a..da99247 100644 --- a/profiler.py +++ b/profiler.py @@ -3,7 +3,7 @@ import argparse import torch from config.parser import parse_args -from raft import RAFT +from sea_raft.raft import RAFT def main(): parser = argparse.ArgumentParser() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..06727e9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,40 @@ +[project] +name = "sea-raft" +version = "0.0.1" +description = """ +We introduce SEA-RAFT, a more simple, efficient, and accurate RAFT for optical flow. Compared with RAFT, SEA-RAFT is trained with a new loss (mixture of Laplace). It directly regresses an initial flow for faster convergence in iterative refinements and introduces rigid-motion pre-training to improve generalization. SEA-RAFT achieves state-of-the-art accuracy on the Spring benchmark with a 3.69 endpoint-error (EPE) and a 0.36 1-pixel outlier rate (1px), representing 22.9% and 17.8% error reduction from best-published results. In addition, SEA-RAFT obtains the best cross-dataset generalization on KITTI and Spring. With its high efficiency, SEA-RAFT operates at least 2.3x faster than existing methods while maintaining competitive performance. +""" +authors = [ + {name = "Wang, Yihan"}, + {name = "Lipson, Lahav"}, + {name = "Deng, Jia"} +] +readme = "README.md" +requires-python = ">= 3.10.13" +dependencies= [ + "torch", + "torchvision", + "torchaudio", + "numpy", + "matplotlib", + "scipy", + "opencv-python", + "tensorboard", + "h5py", + "tqdm", + "einops", + "huggingface-hub" +] + +[project.optional-dependencies] +profiler =[ + "ptlflow" +] +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + + + +[tool.setuptools.packages.find] +include = ["sea_raft*"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 481f03d..fac5395 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ opencv-python tensorboard h5py tqdm -einops \ No newline at end of file +einops +huggingface-hub \ No newline at end of file diff --git a/sea_raft/__init__.py b/sea_raft/__init__.py new file mode 100644 index 0000000..4857edf --- /dev/null +++ b/sea_raft/__init__.py @@ -0,0 +1 @@ +import sea_raft.utils \ No newline at end of file diff --git a/core/corr.py b/sea_raft/corr.py similarity index 99% rename from core/corr.py rename to sea_raft/corr.py index 3183846..1833a01 100644 --- a/core/corr.py +++ b/sea_raft/corr.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from utils.utils import coords_grid, bilinear_sampler +from sea_raft.utils.utils import coords_grid, bilinear_sampler try: import alt_cuda_corr diff --git a/core/datasets.py b/sea_raft/datasets.py similarity index 99% rename from core/datasets.py rename to sea_raft/datasets.py index b5d9e90..f0da965 100644 --- a/core/datasets.py +++ b/sea_raft/datasets.py @@ -12,9 +12,9 @@ from tqdm import tqdm from glob import glob import os.path as osp -from utils import frame_utils -from utils.augmentor import FlowAugmentor, SparseFlowAugmentor -from utils.utils import induced_flow, check_cycle_consistency +from sea_raft.utils import frame_utils +from sea_raft.utils.augmentor import FlowAugmentor, SparseFlowAugmentor +from sea_raft.utils.utils import induced_flow, check_cycle_consistency from ddp_utils import * class FlowDataset(data.Dataset): diff --git a/core/extractor.py b/sea_raft/extractor.py similarity index 98% rename from core/extractor.py rename to sea_raft/extractor.py index c838330..6fa0464 100644 --- a/core/extractor.py +++ b/sea_raft/extractor.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from layer import BasicBlock, conv1x1, conv3x3 +from sea_raft.layer import BasicBlock, conv1x1, conv3x3 class ResNetFPN(nn.Module): """ diff --git a/core/layer.py b/sea_raft/layer.py similarity index 100% rename from core/layer.py rename to sea_raft/layer.py diff --git a/core/loss.py b/sea_raft/loss.py similarity index 100% rename from core/loss.py rename to sea_raft/loss.py diff --git a/core/raft.py b/sea_raft/raft.py similarity index 96% rename from core/raft.py rename to sea_raft/raft.py index 4c8d92f..e9b1fd3 100644 --- a/core/raft.py +++ b/sea_raft/raft.py @@ -4,11 +4,11 @@ import torch.nn as nn import torch.nn.functional as F -from update import BasicUpdateBlock -from corr import CorrBlock -from utils.utils import coords_grid, InputPadder -from extractor import ResNetFPN -from layer import conv1x1, conv3x3 +from sea_raft.update import BasicUpdateBlock +from sea_raft.corr import CorrBlock +from sea_raft.utils.utils import coords_grid, InputPadder +from sea_raft.extractor import ResNetFPN +from sea_raft.layer import conv1x1, conv3x3 from huggingface_hub import PyTorchModelHubMixin diff --git a/core/update.py b/sea_raft/update.py similarity index 97% rename from core/update.py rename to sea_raft/update.py index 08e00fd..419eb96 100644 --- a/core/update.py +++ b/sea_raft/update.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from layer import ConvNextBlock +from sea_raft.layer import ConvNextBlock class FlowHead(nn.Module): def __init__(self, input_dim=128, hidden_dim=256, output_dim=4): diff --git a/core/utils/__init__.py b/sea_raft/utils/__init__.py similarity index 100% rename from core/utils/__init__.py rename to sea_raft/utils/__init__.py diff --git a/core/utils/augmentor.py b/sea_raft/utils/augmentor.py similarity index 100% rename from core/utils/augmentor.py rename to sea_raft/utils/augmentor.py diff --git a/core/utils/flow_transforms.py b/sea_raft/utils/flow_transforms.py similarity index 100% rename from core/utils/flow_transforms.py rename to sea_raft/utils/flow_transforms.py diff --git a/core/utils/flow_viz.py b/sea_raft/utils/flow_viz.py similarity index 100% rename from core/utils/flow_viz.py rename to sea_raft/utils/flow_viz.py diff --git a/core/utils/frame_utils.py b/sea_raft/utils/frame_utils.py similarity index 100% rename from core/utils/frame_utils.py rename to sea_raft/utils/frame_utils.py diff --git a/core/utils/utils.py b/sea_raft/utils/utils.py similarity index 100% rename from core/utils/utils.py rename to sea_raft/utils/utils.py diff --git a/submission.py b/submission.py index b0f45e8..2016cd2 100644 --- a/submission.py +++ b/submission.py @@ -14,13 +14,13 @@ from config.parser import parse_args -import datasets -from raft import RAFT +import sea_raft.datasets +from sea_raft.raft import RAFT from tqdm import tqdm -from utils.flow_viz import flow_to_image -from utils import frame_utils -from utils.utils import load_ckpt, InputPadder +from sea_raft.utils.flow_viz import flow_to_image +from sea_raft.utils import frame_utils +from sea_raft.utils.utils import load_ckpt, InputPadder def forward_flow(args, model, image1, image2): output = model(image1, image2, iters=args.iters, test_mode=True) diff --git a/train.py b/train.py index 7fa5504..eef44be 100644 --- a/train.py +++ b/train.py @@ -9,10 +9,10 @@ import torch import torch.optim as optim -from raft import RAFT -from datasets import fetch_dataloader -from utils.utils import load_ckpt -from loss import sequence_loss +from sea_raft.raft import RAFT +from sea_raft.datasets import fetch_dataloader +from sea_raft.utils.utils import load_ckpt +from sea_raft.loss import sequence_loss from ddp_utils import * os.system("export KMP_INIT_AT_FORK=FALSE")