Skip to content

Commit

Permalink
Merge pull request #16 from rafaelperez/12-update-practical-rife-weig…
Browse files Browse the repository at this point in the history
…hts-to-version-415

Update practical rife weights to version 415
  • Loading branch information
rafaelperez authored Mar 17, 2024
2 parents 4d8f00f + ef8bded commit 2ab34f4
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 53 deletions.
81 changes: 77 additions & 4 deletions model/IFNet_HDv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,11 @@ def __init__(self):
self.encode = Head()

def forward(
self, x, timestep: float = 0.5, scale_list: List[float] = (8.0, 4.0, 2.0, 1.0)
self,
x,
timestep: float = 0.5,
scale_list: List[float] = (8.0, 4.0, 2.0, 1.0),
ensemble: bool = False,
):
channel = x.shape[1] // 2
img0 = x[:, :channel]
Expand All @@ -207,6 +211,15 @@ def forward(
None,
scale=scale_list[0],
)
if ensemble:
f_, m_ = self.block0(
torch.cat((img1[:, :3], img0[:, :3], f1, f0, 1 - timestep), 1),
None,
scale=scale_list[0],
)
flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
mask = (mask + (-m_)) / 2

warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])

Expand All @@ -220,7 +233,27 @@ def forward(
flow,
scale=scale_list[1],
)
mask = m0
if ensemble:
f_, m_ = self.block1(
torch.cat(
(
warped_img1[:, :3],
warped_img0[:, :3],
wf1,
wf0,
1 - timestep,
-mask,
),
1,
),
torch.cat((flow[:, 2:4], flow[:, :2]), 1),
scale=scale_list[1],
)
fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
mask = (m0 + (-m_)) / 2
else:
mask = m0

flow = flow + fd
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
Expand All @@ -235,7 +268,27 @@ def forward(
flow,
scale=scale_list[2],
)
mask = m0
if ensemble:
f_, m_ = self.block2(
torch.cat(
(
warped_img1[:, :3],
warped_img0[:, :3],
wf1,
wf0,
1 - timestep,
-mask,
),
1,
),
torch.cat((flow[:, 2:4], flow[:, :2]), 1),
scale=scale_list[2],
)
fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
mask = (m0 + (-m_)) / 2
else:
mask = m0

flow = flow + fd
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
Expand All @@ -250,7 +303,27 @@ def forward(
flow,
scale=scale_list[3],
)
mask = m0
if ensemble:
f_, m_ = self.block3(
torch.cat(
(
warped_img1[:, :3],
warped_img0[:, :3],
wf1,
wf0,
1 - timestep,
-mask,
),
1,
),
torch.cat((flow[:, 2:4], flow[:, :2]), 1),
scale=scale_list[3],
)
fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
mask = (m0 + (-m_)) / 2
else:
mask = m0

flow = flow + fd
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
Expand Down
File renamed without changes.
Binary file added model/flownet_v4.15.pkl
Binary file not shown.
Binary file modified nuke/Cattery/RIFE/RIFE.cat
Binary file not shown.
2 changes: 1 addition & 1 deletion nuke/Cattery/RIFE/RIFE.gizmo
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Gizmo {
addUserKnob {41 filter l Filter t "Filtering method for the STMap distortion.\n\n<b>Note:</b> Only effective if <b>Channels</b> is set to <b>all</b>." T C_STMap1.filter}
addUserKnob {6 skipKeyframes l "Process only intermediate frames" t "When processing keyframes (e.g. 11, 12, 13), RIFE can introduce slight distortion or filtering. \n\nThis option processes only intermediate frames (e.g. 11.5, 12.5) while skipping the keyframes. \n\nThis is useful for cases where the result needs to match the original unmodified frames exactly." +STARTLINE}
addUserKnob {20 infoTab l Info}
addUserKnob {26 toolName l "" +STARTLINE T "<font size='7'>RIFE</font> v1.0.2 | Released 2024-03-01"}
addUserKnob {26 toolName l "" +STARTLINE T "<font size='7'>RIFE</font> v1.1.1 | Released 2024-03-17"}
addUserKnob {26 ""}
addUserKnob {26 authorName l "" +STARTLINE T "Rafael Silva"}
addUserKnob {26 authorMail l "" +STARTLINE T "<a href=\"mailto:[email protected]\"><span style=\"color:#C8C8C8;\">[email protected]</a>"}
Expand Down
Binary file modified nuke/Cattery/RIFE/RIFE.pt
Binary file not shown.
5 changes: 3 additions & 2 deletions nuke/RIFE_CatFileCreator.nk
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
CatFileCreator {
torchScriptFile RIFE.pt
catFile RIFE.cat
torchScriptFile "\[python \{nuke.script_directory()\}]/Cattery/RIFE/RIFE.pt"
catFile "\[python \{nuke.script_directory()\}]/Cattery/RIFE/RIFE.cat"
channelsIn rgba.red,rgba.green,rgba.blue,rgba.alpha,forward.u,forward.v,backward.u,backward.v
channelsOut rgba.red,rgba.green,rgba.blue,rgba.alpha,depth.Z
modelId RIFEv4
Expand All @@ -13,4 +13,5 @@ CatFileCreator {
addUserKnob {7 scale R 0 8}
scale 1
addUserKnob {6 optical_flow +STARTLINE}
addUserKnob {6 ensemble +STARTLINE}
}
125 changes: 79 additions & 46 deletions nuke_rife.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)
PATH = "model/flownet.pkl"
PATH = "model/flownet_v4.14.pkl"
TORCHSCRIPT_MODEL = "./nuke/Cattery/RIFE/RIFE.pt"


def load_flownet():
Expand All @@ -20,53 +21,85 @@ def convert(param):
return flownet


def trace_rife():
class FlowNetNuke(torch.nn.Module):
def __init__(
self, timestep: float = 0.5, scale: float = 1.0, optical_flow: int = 0
):
super().__init__()
self.optical_flow = optical_flow
self.timestep = timestep
self.scale = scale
self.flownet = load_flownet()
self.flownet_half = load_flownet().half()

def forward(self, x):
b, c, h, w = x.shape
dtype = x.dtype

timestep = self.timestep
scale = (
self.scale if self.scale in [0.125, 0.25, 0.5, 1.0, 2.0, 4.0] else 1.0
)
device = torch.device("cuda") if x.is_cuda else torch.device("cpu")

# Padding
padding_factor = max(128, int(128 / scale))
pad_h = ((h - 1) // padding_factor + 1) * padding_factor
pad_w = ((w - 1) // padding_factor + 1) * padding_factor
pad_dims = (0, pad_w - w, 0, pad_h - h)
x = torch.nn.functional.pad(x, pad_dims)

scale_list = (8.0 / scale, 4.0 / scale, 2.0 / scale, 1.0 / scale)

if dtype == torch.float32:
flow, mask, image = self.flownet((x), timestep, scale_list)
else:
flow, mask, image = self.flownet_half((x), timestep, scale_list)

# Return the optical flow and mask
if self.optical_flow:
return torch.cat((flow[:, :, :h, :w], mask[:, :, :h, :w]), 1)

# Return the interpolated frames
alpha = torch.ones((b, 1, h, w), dtype=dtype, device=device)
return torch.cat((image[:, :, :h, :w], alpha), dim=1).contiguous()

class FlowNetNuke(torch.nn.Module):
    """
    Wrapper module that runs the flow network returned by ``load_flownet()``
    on a single input tensor and returns either the interpolated image or the
    flow and mask.

    The constructor arguments are stored as plain attributes and read on every
    ``forward`` call.

    Args:
        timestep (float): Forwarded unchanged to the flow network as its
            ``timestep`` argument. Presumably the position of the interpolated
            frame between the two input frames (0.5 = midway) - confirm
            against the flow network. Default is 0.5.
        scale (float): Divisor applied to the base scale list
            ``(8, 4, 2, 1)`` passed to the flow network. Only 0.125, 0.25,
            0.5, 1.0, 2.0 and 4.0 are honoured; any other value is treated
            as 1.0. Default is 1.0.
        optical_flow (int): When non-zero, ``forward`` returns the flow and
            mask instead of the interpolated image. Default is 0.
        ensemble (int): When non-zero, ``True`` is forwarded to the flow
            network's ``ensemble`` argument. Default is 0.
    """

    def __init__(
        self,
        timestep: float = 0.5,
        scale: float = 1.0,
        optical_flow: int = 0,
        ensemble: int = 0,
    ):
        super().__init__()
        self.optical_flow = optical_flow
        self.timestep = timestep
        self.scale = scale
        self.ensemble = ensemble
        # Two independently loaded copies of the network: one left as loaded,
        # one converted with .half(). forward() picks one based on input dtype.
        self.flownet = load_flownet()
        self.flownet_half = load_flownet().half()

    def forward(self, x):
        """
        Pad the input, run the flow network, and crop the result back to the
        input resolution.

        Args:
            x (torch.Tensor): 4-D input tensor unpacked as
                (batch, channels, height, width). The channel layout is
                whatever the flow network expects; it is not inspected here.

        Returns:
            torch.Tensor: If ``self.optical_flow`` is non-zero, the flow and
                mask returned by the flow network, each cropped to
                (height, width) and concatenated along the channel dimension.
                Otherwise, the image returned by the flow network cropped to
                (height, width), with one extra all-ones channel appended
                along the channel dimension.
        """
        b, c, h, w = x.shape
        dtype = x.dtype
        timestep = self.timestep
        # The flag is stored as an int; convert it to the bool passed to the network.
        ensemble = bool(self.ensemble)
        # Unsupported scale values silently fall back to 1.0.
        scale = self.scale if self.scale in [0.125, 0.25, 0.5, 1.0, 2.0, 4.0] else 1.0
        device = torch.device("cuda") if x.is_cuda else torch.device("cpu")

        # Padding
        # Round height and width up to the next multiple of padding_factor
        # (128, or 128 / scale when scale < 1). Padding is added only on the
        # right and bottom edges, so cropping [:h, :w] later undoes it.
        padding_factor = max(128, int(128 / scale))
        pad_h = ((h - 1) // padding_factor + 1) * padding_factor
        pad_w = ((w - 1) // padding_factor + 1) * padding_factor
        pad_dims = (0, pad_w - w, 0, pad_h - h)
        x = torch.nn.functional.pad(x, pad_dims)

        scale_list = (8.0 / scale, 4.0 / scale, 2.0 / scale, 1.0 / scale)

        # float32 input uses the network as loaded; every other dtype is
        # routed to the copy converted with .half().
        if dtype == torch.float32:
            flow, mask, image = self.flownet((x), timestep, scale_list, ensemble)
        else:
            flow, mask, image = self.flownet_half((x), timestep, scale_list, ensemble)

        # Return the optical flow and mask
        if self.optical_flow:
            return torch.cat((flow[:, :, :h, :w], mask[:, :, :h, :w]), 1)

        # Return the interpolated frames
        # A constant all-ones channel is appended after the image channels
        # (presumably the alpha channel of the output - confirm against the
        # consumer of this tensor).
        alpha = torch.ones((b, 1, h, w), dtype=dtype, device=device)
        return torch.cat((image[:, :, :h, :w], alpha), dim=1).contiguous()


def trace_rife(model_file=TORCHSCRIPT_MODEL):
"""
Traces the RIFE model using FlowNetNuke and saves the traced flow model.
Returns:
None
"""
with torch.jit.optimized_execution(True):
rife_nuke = torch.jit.script(FlowNetNuke().eval().requires_grad_(False))
model_file = "./nuke/Cattery/RIFE/RIFE.pt"
model_file = TORCHSCRIPT_MODEL
rife_nuke.save(model_file)
LOGGER.info(rife_nuke.code)
LOGGER.info(rife_nuke.graph)
Expand Down

0 comments on commit 2ab34f4

Please sign in to comment.