Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a setup.py file for a local raft package #112

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file removed core/utils/__init__.py
Empty file.
10 changes: 3 additions & 7 deletions demo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import sys
sys.path.append('core')

import argparse
import os
import cv2
Expand All @@ -10,9 +7,8 @@
from PIL import Image

from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder

from raft.utils import flow_viz
from raft.utils.utils import InputPadder


DEVICE = 'cuda'
Expand Down Expand Up @@ -41,7 +37,7 @@ def viz(img, flo):

def demo(args):
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model.load_state_dict(torch.load(args.model, map_location=DEVICE))

model = model.module
model.to(DEVICE)
Expand Down
11 changes: 4 additions & 7 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import sys
sys.path.append('core')

from PIL import Image
import argparse
import os
Expand All @@ -10,12 +7,12 @@
import torch.nn.functional as F
import matplotlib.pyplot as plt

import datasets
from utils import flow_viz
from utils import frame_utils
import raft.datasets as datasets
from raft.utils import flow_viz
from raft.utils import frame_utils

from raft import RAFT
from utils.utils import InputPadder, forward_interpolate
from raft.utils.utils import InputPadder, forward_interpolate


@torch.no_grad()
Expand Down
1 change: 1 addition & 0 deletions raft/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from raft.raft import RAFT
2 changes: 1 addition & 1 deletion core/corr.py → raft/corr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import torch.nn.functional as F
from utils.utils import bilinear_sampler, coords_grid
from raft.utils.utils import bilinear_sampler, coords_grid

try:
import alt_cuda_corr
Expand Down
4 changes: 2 additions & 2 deletions core/datasets.py → raft/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from glob import glob
import os.path as osp

from utils import frame_utils
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
from raft.utils import frame_utils
from raft.utils.augmentor import FlowAugmentor, SparseFlowAugmentor


class FlowDataset(data.Dataset):
Expand Down
File renamed without changes.
8 changes: 4 additions & 4 deletions core/raft.py → raft/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import torch.nn as nn
import torch.nn.functional as F

from update import BasicUpdateBlock, SmallUpdateBlock
from extractor import BasicEncoder, SmallEncoder
from corr import CorrBlock, AlternateCorrBlock
from utils.utils import bilinear_sampler, coords_grid, upflow8
from raft.update import BasicUpdateBlock, SmallUpdateBlock
from raft.extractor import BasicEncoder, SmallEncoder
from raft.corr import CorrBlock, AlternateCorrBlock
from raft.utils.utils import bilinear_sampler, coords_grid, upflow8

try:
autocast = torch.cuda.amp.autocast
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
16 changes: 16 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from setuptools import setup, find_packages

setup(
name='raft',
version='1.0.0',
packages=find_packages(),
install_requires=[
'torch',
'torchvision',
'opencv-python',
'matplotlib',
'tensorboard',
'scipy'
]
)

6 changes: 2 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from __future__ import print_function, division
import sys
sys.path.append('core')

import argparse
import os
Expand All @@ -16,8 +14,8 @@

from torch.utils.data import DataLoader
from raft import RAFT
import raft.datasets as datasets
import evaluate
import datasets

from torch.utils.tensorboard import SummaryWriter

Expand Down Expand Up @@ -244,4 +242,4 @@ def train(args):
if not os.path.isdir('checkpoints'):
os.mkdir('checkpoints')

train(args)
train(args)