Add option to initialize flow pred.
ChristophReich1996 committed Jul 11, 2024
1 parent 8f9d810 commit ac5eae6
Showing 1 changed file with 36 additions and 31 deletions.
67 changes: 36 additions & 31 deletions torchvision/models/optical_flow/raft.py
@@ -14,7 +14,6 @@
 from .._utils import handle_legacy_interface
 from ._utils import grid_sample, make_coords_grid, upsample_flow
 
-
 __all__ = (
     "RAFT",
     "raft_large",
@@ -120,7 +119,7 @@ class FeatureEncoder(nn.Module):
     """
 
     def __init__(
-        self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), strides=(2, 1, 2, 2), norm_layer=nn.BatchNorm2d
+            self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), strides=(2, 1, 2, 2), norm_layer=nn.BatchNorm2d
     ):
         super().__init__()
 
@@ -149,7 +148,7 @@ def __init__(
 
         num_downsamples = len(list(filter(lambda s: s == 2, strides)))
         self.output_dim = layers[-1]
-        self.downsample_factor = 2**num_downsamples
+        self.downsample_factor = 2 ** num_downsamples
 
     def _make_2_blocks(self, block, in_channels, out_channels, norm_layer, first_stride):
         block1 = block(in_channels, out_channels, norm_layer=norm_layer, stride=first_stride)
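
A quick sanity check on the numbers above: the default strides (2, 1, 2, 2) contain three stride-2 stages, so the encoder output lives at 1/8 resolution. This is what drives both the divisible-by-8 input check and the (N, 2, H // 8, W // 8) shape expected for flow_init in the forward() changes below; a minimal sketch:

    # Values taken from the defaults shown in this hunk.
    strides = (2, 1, 2, 2)
    num_downsamples = len(list(filter(lambda s: s == 2, strides)))  # 3 stride-2 stages
    downsample_factor = 2 ** num_downsamples  # 8, so features (and flow_init) live at H/8 x W/8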
@@ -481,13 +480,16 @@ def __init__(self, *, feature_encoder, context_encoder, corr_block, update_block
         if not hasattr(self.update_block, "hidden_state_size"):
             raise ValueError("The update_block parameter should expose a 'hidden_state_size' attribute.")
 
-    def forward(self, image1, image2, num_flow_updates: int = 12):
+    def forward(self, image1, image2, num_flow_updates: int = 12, flow_init: Optional[Tensor] = None):
 
         batch_size, _, h, w = image1.shape
         if (h, w) != image2.shape[-2:]:
             raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
         if not (h % 8 == 0) and (w % 8 == 0):
             raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")
+        if (flow_init is not None) and ((batch_size, 2, h // 8, w // 8) != flow_init.shape):
+            raise ValueError(
+                f"initial optical flow must have the shape ({batch_size}, 2, {h // 8}, {w // 8}), instead got {flow_init.shape}")
 
         fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
         fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
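
With this patch applied, forward() accepts an optional warm-start flow. A minimal usage sketch (the flow_init keyword exists only with this change; raft_small and num_flow_updates are standard torchvision API, and the input sizes are arbitrary multiples of 8):

    import torch
    from torchvision.models.optical_flow import raft_small

    model = raft_small(weights=None).eval()  # pass pretrained weights for real use
    img1 = torch.randn(1, 3, 368, 496)  # H and W must be divisible by 8
    img2 = torch.randn(1, 3, 368, 496)

    # flow_init must be (N, 2, H // 8, W // 8), here (1, 2, 46, 62); any other shape raises ValueError
    flow_init = torch.zeros(1, 2, 368 // 8, 496 // 8)
    with torch.no_grad():
        predictions = model(img1, img2, num_flow_updates=12, flow_init=flow_init)
    flow = predictions[-1]  # last iterate is the most refined, full-resolution prediction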
@@ -516,6 +518,9 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
         coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
         coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
 
+        if flow_init is not None:
+            coords1 = coords1 + flow_init
+
         flow_predictions = []
         for _ in range(num_flow_updates):
             coords1 = coords1.detach()  # Don't backpropagate gradients through this branch, see paper
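
Since the per-pixel flow estimate is coords1 - coords0, adding flow_init to coords1 makes the refinement loop start from that flow instead of zero. This enables the RAFT paper's warm-start mode on video: feed each frame pair the previous pair's prediction. A hedged sketch (model and frame_pairs are assumed defined, e.g. as in the example above; forward() returns full-resolution flows, so they must be resized to 1/8 scale with values divided by 8 before reuse, and this simple resampling skips the forward-projection step the paper describes):

    import torch.nn.functional as F

    def to_flow_init(full_res_flow):
        # Hypothetical helper: resize a full-resolution flow to (N, 2, H/8, W/8).
        # Shrinking a flow field also shrinks its displacement values, hence the / 8.
        return F.interpolate(full_res_flow, scale_factor=1 / 8, mode="bilinear", align_corners=False) / 8

    flow_init = None
    for img1, img2 in frame_pairs:  # consecutive frame pairs of a video
        predictions = model(img1, img2, flow_init=flow_init)
        flow_init = to_flow_init(predictions[-1].detach())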
@@ -754,33 +759,33 @@ class Raft_Small_Weights(WeightsEnum):
 
 
 def _raft(
-    *,
-    weights=None,
-    progress=False,
-    # Feature encoder
-    feature_encoder_layers,
-    feature_encoder_block,
-    feature_encoder_norm_layer,
-    # Context encoder
-    context_encoder_layers,
-    context_encoder_block,
-    context_encoder_norm_layer,
-    # Correlation block
-    corr_block_num_levels,
-    corr_block_radius,
-    # Motion encoder
-    motion_encoder_corr_layers,
-    motion_encoder_flow_layers,
-    motion_encoder_out_channels,
-    # Recurrent block
-    recurrent_block_hidden_state_size,
-    recurrent_block_kernel_size,
-    recurrent_block_padding,
-    # Flow Head
-    flow_head_hidden_size,
-    # Mask predictor
-    use_mask_predictor,
-    **kwargs,
+        *,
+        weights=None,
+        progress=False,
+        # Feature encoder
+        feature_encoder_layers,
+        feature_encoder_block,
+        feature_encoder_norm_layer,
+        # Context encoder
+        context_encoder_layers,
+        context_encoder_block,
+        context_encoder_norm_layer,
+        # Correlation block
+        corr_block_num_levels,
+        corr_block_radius,
+        # Motion encoder
+        motion_encoder_corr_layers,
+        motion_encoder_flow_layers,
+        motion_encoder_out_channels,
+        # Recurrent block
+        recurrent_block_hidden_state_size,
+        recurrent_block_kernel_size,
+        recurrent_block_padding,
+        # Flow Head
+        flow_head_hidden_size,
+        # Mask predictor
+        use_mask_predictor,
+        **kwargs,
 ):
     feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder(
         block=feature_encoder_block, layers=feature_encoder_layers, norm_layer=feature_encoder_norm_layer
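
The hunk above only re-indents _raft's keyword parameters; behavior is unchanged. _raft is the private builder behind the public constructors, which are the usual entry point, e.g. (standard torchvision API, independent of this patch):

    from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

    weights = Raft_Large_Weights.DEFAULT
    model = raft_large(weights=weights, progress=True).eval()
    preprocess = weights.transforms()  # the preprocessing the weights were trained with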
