diff --git a/model/IFNet_HDv3.py b/model/IFNet_HDv3.py index e10cf0c..967e3e0 100755 --- a/model/IFNet_HDv3.py +++ b/model/IFNet_HDv3.py @@ -188,7 +188,11 @@ def __init__(self): self.encode = Head() def forward( - self, x, timestep: float = 0.5, scale_list: List[float] = (8.0, 4.0, 2.0, 1.0) + self, + x, + timestep: float = 0.5, + scale_list: List[float] = (8.0, 4.0, 2.0, 1.0), + ensemble: bool = False, ): channel = x.shape[1] // 2 img0 = x[:, :channel] @@ -207,6 +211,15 @@ def forward( None, scale=scale_list[0], ) + if ensemble: + f_, m_ = self.block0( + torch.cat((img1[:, :3], img0[:, :3], f1, f0, 1 - timestep), 1), + None, + scale=scale_list[0], + ) + flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 + mask = (mask + (-m_)) / 2 + warped_img0 = warp(img0, flow[:, :2]) warped_img1 = warp(img1, flow[:, 2:4]) @@ -220,7 +233,27 @@ def forward( flow, scale=scale_list[1], ) - mask = m0 + if ensemble: + f_, m_ = self.block1( + torch.cat( + ( + warped_img1[:, :3], + warped_img0[:, :3], + wf1, + wf0, + 1 - timestep, + -mask, + ), + 1, + ), + torch.cat((flow[:, 2:4], flow[:, :2]), 1), + scale=scale_list[1], + ) + fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 + mask = (m0 + (-m_)) / 2 + else: + mask = m0 + flow = flow + fd warped_img0 = warp(img0, flow[:, :2]) warped_img1 = warp(img1, flow[:, 2:4]) @@ -235,7 +268,27 @@ def forward( flow, scale=scale_list[2], ) - mask = m0 + if ensemble: + f_, m_ = self.block2( + torch.cat( + ( + warped_img1[:, :3], + warped_img0[:, :3], + wf1, + wf0, + 1 - timestep, + -mask, + ), + 1, + ), + torch.cat((flow[:, 2:4], flow[:, :2]), 1), + scale=scale_list[2], + ) + fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 + mask = (m0 + (-m_)) / 2 + else: + mask = m0 + flow = flow + fd warped_img0 = warp(img0, flow[:, :2]) warped_img1 = warp(img1, flow[:, 2:4]) @@ -250,7 +303,27 @@ def forward( flow, scale=scale_list[3], ) - mask = m0 + if ensemble: + f_, m_ = self.block3( + torch.cat( + ( + warped_img1[:, :3], + warped_img0[:, :3], + wf1, + wf0, + 1 - timestep, + -mask, + ), + 1, + ), + torch.cat((flow[:, 2:4], flow[:, :2]), 1), + scale=scale_list[3], + ) + fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 + mask = (m0 + (-m_)) / 2 + else: + mask = m0 + flow = flow + fd warped_img0 = warp(img0, flow[:, :2]) warped_img1 = warp(img1, flow[:, 2:4]) diff --git a/model/flownet.pkl b/model/flownet_v4.14.pkl similarity index 100% rename from model/flownet.pkl rename to model/flownet_v4.14.pkl diff --git a/model/flownet_v4.15.pkl b/model/flownet_v4.15.pkl new file mode 100644 index 0000000..c57fbb7 Binary files /dev/null and b/model/flownet_v4.15.pkl differ diff --git a/nuke/Cattery/RIFE/RIFE.cat b/nuke/Cattery/RIFE/RIFE.cat index 4c54467..b0d9e0a 100644 Binary files a/nuke/Cattery/RIFE/RIFE.cat and b/nuke/Cattery/RIFE/RIFE.cat differ diff --git a/nuke/Cattery/RIFE/RIFE.gizmo b/nuke/Cattery/RIFE/RIFE.gizmo index bfb2fd3..f5b0cfc 100644 --- a/nuke/Cattery/RIFE/RIFE.gizmo +++ b/nuke/Cattery/RIFE/RIFE.gizmo @@ -27,7 +27,7 @@ Gizmo { addUserKnob {41 filter l Filter t "Filtering method for the STMap distortion.\n\nNote: Only effective if Channels is set to all." T C_STMap1.filter} addUserKnob {6 skipKeyframes l "Process only intermediate frames" t "When processing keyframes (e.g. 11, 12, 13), RIFE can introduce slight distortion or filtering. \n\nThis option processes only intermediate frames (e.g. 11.5, 12.5) while skipping the keyframes. \n\nThis is useful for cases where the result needs to match the original unmodified frames exactly." +STARTLINE} addUserKnob {20 infoTab l Info} - addUserKnob {26 toolName l "" +STARTLINE T "RIFE v1.0.2 | Released 2024-03-01"} + addUserKnob {26 toolName l "" +STARTLINE T "RIFE v1.1.1 | Released 2024-03-17"} addUserKnob {26 ""} addUserKnob {26 authorName l "" +STARTLINE T "Rafael Silva"} addUserKnob {26 authorMail l "" +STARTLINE T "rafael@rafael.ai"} diff --git a/nuke/Cattery/RIFE/RIFE.pt b/nuke/Cattery/RIFE/RIFE.pt index 822b517..a802564 100644 Binary files a/nuke/Cattery/RIFE/RIFE.pt and b/nuke/Cattery/RIFE/RIFE.pt differ diff --git a/nuke/RIFE_CatFileCreator.nk b/nuke/RIFE_CatFileCreator.nk index 07df647..979e97f 100644 --- a/nuke/RIFE_CatFileCreator.nk +++ b/nuke/RIFE_CatFileCreator.nk @@ -1,6 +1,6 @@ CatFileCreator { - torchScriptFile RIFE.pt - catFile RIFE.cat + torchScriptFile "\[python \{nuke.script_directory()\}]/Cattery/RIFE/RIFE.pt" + catFile "\[python \{nuke.script_directory()\}]/Cattery/RIFE/RIFE.cat" channelsIn rgba.red,rgba.green,rgba.blue,rgba.alpha,forward.u,forward.v,backward.u,backward.v channelsOut rgba.red,rgba.green,rgba.blue,rgba.alpha,depth.Z modelId RIFEv4 @@ -13,4 +13,5 @@ CatFileCreator { addUserKnob {7 scale R 0 8} scale 1 addUserKnob {6 optical_flow +STARTLINE} + addUserKnob {6 ensemble +STARTLINE} } \ No newline at end of file diff --git a/nuke_rife.py b/nuke_rife.py index 6a18b06..5f3bc5d 100644 --- a/nuke_rife.py +++ b/nuke_rife.py @@ -4,7 +4,8 @@ logging.basicConfig(level=logging.INFO) LOGGER = logging.getLogger(__name__) -PATH = "model/flownet.pkl" +PATH = "model/flownet_v4.14.pkl" +TORCHSCRIPT_MODEL = "./nuke/Cattery/RIFE/RIFE.pt" def load_flownet(): @@ -20,53 +21,85 @@ def convert(param): return flownet -def trace_rife(): - class FlowNetNuke(torch.nn.Module): - def __init__( - self, timestep: float = 0.5, scale: float = 1.0, optical_flow: int = 0 - ): - super().__init__() - self.optical_flow = optical_flow - self.timestep = timestep - self.scale = scale - self.flownet = load_flownet() - self.flownet_half = load_flownet().half() - - def forward(self, x): - b, c, h, w = x.shape - dtype = x.dtype - - timestep = self.timestep - scale = ( - self.scale if self.scale in [0.125, 0.25, 0.5, 1.0, 2.0, 4.0] else 1.0 - ) - device = torch.device("cuda") if x.is_cuda else torch.device("cpu") - - # Padding - padding_factor = max(128, int(128 / scale)) - pad_h = ((h - 1) // padding_factor + 1) * padding_factor - pad_w = ((w - 1) // padding_factor + 1) * padding_factor - pad_dims = (0, pad_w - w, 0, pad_h - h) - x = torch.nn.functional.pad(x, pad_dims) - - scale_list = (8.0 / scale, 4.0 / scale, 2.0 / scale, 1.0 / scale) - - if dtype == torch.float32: - flow, mask, image = self.flownet((x), timestep, scale_list) - else: - flow, mask, image = self.flownet_half((x), timestep, scale_list) - - # Return the optical flow and mask - if self.optical_flow: - return torch.cat((flow[:, :, :h, :w], mask[:, :, :h, :w]), 1) - - # Return the interpolated frames - alpha = torch.ones((b, 1, h, w), dtype=dtype, device=device) - return torch.cat((image[:, :, :h, :w], alpha), dim=1).contiguous() - +class FlowNetNuke(torch.nn.Module): + """ + FlowNetNuke is a module that performs optical flow estimation and frame interpolation using the RIFE algorithm. + + Args: + timestep (float): The time interval between consecutive frames. Default is 0.5. + scale (float): The scale factor for resizing the input frames. Default is 1.0. + optical_flow (int): Flag indicating whether to return the optical flow and mask or the interpolated frames. + Set to 1 to return optical flow and mask, and 0 to return interpolated frames. Default is 0. + """ + + def __init__( + self, + timestep: float = 0.5, + scale: float = 1.0, + optical_flow: int = 0, + ensemble: int = 0, + ): + super().__init__() + self.optical_flow = optical_flow + self.timestep = timestep + self.scale = scale + self.ensemble = ensemble + self.flownet = load_flownet() + self.flownet_half = load_flownet().half() + + def forward(self, x): + """ + Forward pass of the RIFE model. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width). + + Returns: + torch.Tensor: If `optical_flow` is True, returns the optical flow and mask concatenated along the channel dimension, + with shape (batch_size, 2 * channels, height, width). + If `optical_flow` is False, returns the interpolated frames and alpha channel concatenated along the channel dimension, + with shape (batch_size, (channels + 1), height, width). + """ + b, c, h, w = x.shape + dtype = x.dtype + timestep = self.timestep + ensemble = bool(self.ensemble) + scale = self.scale if self.scale in [0.125, 0.25, 0.5, 1.0, 2.0, 4.0] else 1.0 + device = torch.device("cuda") if x.is_cuda else torch.device("cpu") + + # Padding + padding_factor = max(128, int(128 / scale)) + pad_h = ((h - 1) // padding_factor + 1) * padding_factor + pad_w = ((w - 1) // padding_factor + 1) * padding_factor + pad_dims = (0, pad_w - w, 0, pad_h - h) + x = torch.nn.functional.pad(x, pad_dims) + + scale_list = (8.0 / scale, 4.0 / scale, 2.0 / scale, 1.0 / scale) + + if dtype == torch.float32: + flow, mask, image = self.flownet((x), timestep, scale_list, ensemble) + else: + flow, mask, image = self.flownet_half((x), timestep, scale_list, ensemble) + + # Return the optical flow and mask + if self.optical_flow: + return torch.cat((flow[:, :, :h, :w], mask[:, :, :h, :w]), 1) + + # Return the interpolated frames + alpha = torch.ones((b, 1, h, w), dtype=dtype, device=device) + return torch.cat((image[:, :, :h, :w], alpha), dim=1).contiguous() + + +def trace_rife(model_file=TORCHSCRIPT_MODEL): + """ + Traces the RIFE model using FlowNetNuke and saves the traced flow model. + + Returns: + None + """ with torch.jit.optimized_execution(True): rife_nuke = torch.jit.script(FlowNetNuke().eval().requires_grad_(False)) - model_file = "./nuke/Cattery/RIFE/RIFE.pt" + model_file = TORCHSCRIPT_MODEL rife_nuke.save(model_file) LOGGER.info(rife_nuke.code) LOGGER.info(rife_nuke.graph)