Bump version v1.2.0
Add support for video outpainting (model=1)
dan64 committed Jun 29, 2024
1 parent 12aac72 commit c993544
Showing 5 changed files with 566 additions and 28 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -44,6 +44,11 @@ clip = propainter(clip, clip_mask=clipMask)
# ProPainter using a mask image region
clip = propainter(clip, img_mask_path="sample.png", mask_region=(460,280,68,28))

# ProPainter using outpainting
w = clip.width + 8
h = clip.height + 32
clip = propainter(clip, model=1, length=50, mask_dilation=0, outpaint_size=(w, h))

```
See `__init__.py` for the description of the parameters.
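
As a rough end-to-end sketch, an outpainting script could look like the following (the source filter, file name and color matrix are placeholder assumptions, not part of the package):

```
# hypothetical outpainting script: extend the frame by 8 px in width and 32 px in height
import vapoursynth as vs
from vspropainter import propainter

core = vs.core
clip = core.lsmas.LWLibavSource("input.mp4")  # placeholder source filter and file name
# ProPainter expects RGB24 "full range" input
clip = core.resize.Bicubic(clip, format=vs.RGB24, matrix_in_s="709", range_s="full")
clip = propainter(clip, model=1, length=50, mask_dilation=0,
                  outpaint_size=(clip.width + 8, clip.height + 32))
# convert back to YUV for encoding (matrix and format chosen only for illustration)
clip = core.resize.Bicubic(clip, format=vs.YUV420P8, matrix_s="709")
clip.set_output()
```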

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -7,7 +7,7 @@ exclude = []

[project]
name = "vspropainter"
version = "1.1.0"
version = "1.2.0"
description = "ProPainter function for VapourSynth"
readme = "README.md"
requires-python = ">=3.10"
214 changes: 199 additions & 15 deletions vspropainter/__init__.py
@@ -4,7 +4,7 @@
Date: 2024-05-26
version:
LastEditors: Dan64
LastEditTime: 2024-06-05
LastEditTime: 2024-06-23
-------------------------------------------------------------------------------
Description:
-------------------------------------------------------------------------------
@@ -21,17 +21,79 @@
import torch.nn.functional as F
import vapoursynth as vs
from functools import partial
from vspropainter.propainter_render import ModelProPainter
from vspropainter.propainter_render import ModelProPainterIn, ModelProPainterOut
from vspropainter.propainter_utils import *

__version__ = "1.1.0"
__version__ = "1.2.0"

os.environ["CUDA_MODULE_LOADING"] = "LAZY"

model_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "weights")

# @torch.inference_mode()

def propainter(
clip: vs.VideoNode,
length: int = 80,
clip_mask: vs.VideoNode = None,
img_mask_path: str = None,
mask_dilation: int = 4,
neighbor_length: int = 10,
ref_stride: int = 10,
raft_iter: int = 20,
mask_region: tuple[int, int, int, int] = None,
sc_threshold: float = 0.0,
model: int = 0,
outpaint_size: tuple[int, int] = None,
weights_dir: str = model_dir,
enable_fp16: bool = True,
device_index: int = 0,
inference_mode: bool = False
) -> vs.VideoNode:
"""ProPainter: Improving Propagation and Transformer for Video Inpainting
:param clip: Clip to process. Only RGB24 "full range" format is supported.
:param length: Sequence length that the model processes (min. 12 frames). High values
increase the inference speed but also increase the memory usage. Default: 80
:param clip_mask: Clip mask; it must have the same size and length as the input clip. Default: None
:param img_mask_path: Path of the mask image; it must have the same size as the input clip. Default: None
:param mask_dilation: Mask dilation for video and flow masking. Default: 4
:param neighbor_length: Length of local neighboring frames. Low values decrease the memory usage.
High values can help improve the quality on fast-moving sequences.
Default: 10
:param ref_stride: Stride of global reference frames. High values reduce the memory usage
and increase the inference speed, but can affect the inference quality. Default: 10
:param raft_iter: Iterations for RAFT inference. Low values increase the inference
speed but can affect the output quality. Default: 20
:param mask_region: Restricts the region of the mask, format: (width, height, left, top).
The region must be big enough to allow the inference. Default: None
:param sc_threshold: If > 0, represents the scene change threshold used to generate the reference
frames for ProPainter, range [0, 1]. Default: 0.0
:param model: Model used by ProPainter to render the frames, available values are:
0: Inpainting Mask Mode (using img_mask/clip_mask)
1: Outpainting Mode (using outpaint_size)
Default: 0
:param outpaint_size: Size of the extrapolated frames, format: (width, height). Default: None
:param weights_dir: Path of the folder containing the model weights.
:param enable_fp16: If True, use fp16 (half precision) during inference (recommended for
RTX30 or above). Default: True
:param device_index: Device ordinal of the GPU (if -1, CPU mode is enabled). Default: 0
:param inference_mode: Enable/Disable torch inference mode. Default: False
"""

if model not in (0, 1):
raise vs.Error("propainter: model must be 0 or 1")

if model == 0:
return propainter_inpaint(clip, length, clip_mask, img_mask_path, mask_dilation,
neighbor_length, ref_stride, raft_iter, mask_region, sc_threshold,
weights_dir, enable_fp16, device_index, inference_mode)
else:
return propainter_outpaint(clip, length, outpaint_size, mask_dilation, neighbor_length,
ref_stride, raft_iter, weights_dir, enable_fp16, device_index, inference_mode)
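
# Example (sketch): typical calls for the two modes, assuming `clip` is an RGB24 clip
# and `mask` is a mask clip with the same size and length as `clip`:
#   inpainted = propainter(clip, clip_mask=mask)  # model=0 (default)
#   outpainted = propainter(clip, model=1, outpaint_size=(clip.width + 8, clip.height + 32))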


# @torch.inference_mode()
def propainter_inpaint(
clip: vs.VideoNode,
length: int = 100,
clip_mask: vs.VideoNode = None,
@@ -124,7 +124,8 @@ def propainter(

cache = {}

def inference_img_mask(n: int, f: list[vs.VideoFrame], v_clip: vs.VideoFrame = None, ppaint: ModelProPainter = None,
def inference_img_mask(n: int, f: list[vs.VideoFrame], v_clip: vs.VideoFrame = None,
ppaint: ModelProPainterIn = None,
batch_size: int = 25, use_half: bool = False, sc_thresh: bool = False) -> vs.VideoFrame:

if str(n) not in cache:
@@ -160,7 +223,7 @@ def inference_img_mask(n: int, f: list[vs.VideoFrame], v_clip: vs.VideoFrame = N
return np_array_to_frame(cache[str(n)], f[1].copy())

def inference_clip_mask(n: int, f: list[vs.VideoFrame], v_clip: vs.VideoFrame = None, m_clip: vs.VideoFrame = None,
ppaint: ModelProPainter = None, batch_size: int = 25,
ppaint: ModelProPainterIn = None, batch_size: int = 25,
use_half: bool = False, sc_thresh: bool = False) -> vs.VideoFrame:

if str(n) not in cache:
@@ -205,32 +268,153 @@ def inference_clip_mask(n: int, f: list[vs.VideoFrame], v_clip: vs.VideoFrame =
else:
sc_thresh = False

ppaint = ModelProPainter(device, weights_dir, img_mask_path, mask_dilation, neighbor_length,
ref_stride, raft_iter)
ppaint = ModelProPainterIn(device, weights_dir, img_mask_path, mask_dilation, neighbor_length,
ref_stride, raft_iter, (clip.width, clip.height))

base = clip.std.BlankClip(width=clip.width, height=clip.height, keep=True)

if clip_mask is None:
if mask_region is None:
clip_new = base.std.ModifyFrame(clips=[clip, base], selector=partial(inference_img_mask, v_clip=clip,
ppaint=ppaint, batch_size=length, use_half=use_half, sc_thresh=sc_thresh))
ppaint=ppaint, batch_size=length,
use_half=use_half,
sc_thresh=sc_thresh))
else:
ppaint.img_mask_crop(mask_region)
base_c = clip_crop(base, mask_region)
clip_c = clip_crop(clip, mask_region)
v_cropped = base_c.std.ModifyFrame(clips=[clip_c, base_c], selector=partial(inference_img_mask, v_clip=clip_c,
ppaint=ppaint, batch_size=length, use_half=use_half, sc_thresh=sc_thresh))
v_cropped = base_c.std.ModifyFrame(clips=[clip_c, base_c],
selector=partial(inference_img_mask, v_clip=clip_c,
ppaint=ppaint, batch_size=length, use_half=use_half,
sc_thresh=sc_thresh))
clip_new = mask_overlay(clip, v_cropped, x=mask_region[2], y=mask_region[3])
else:
if mask_region is None:
clip_new = base.std.ModifyFrame(clips=[clip, base, clip_mask], selector=partial(inference_clip_mask, v_clip=clip,
m_clip=clip_mask, ppaint=ppaint, batch_size=length, use_half=use_half, sc_thresh=sc_thresh))
clip_new = base.std.ModifyFrame(clips=[clip, base, clip_mask],
selector=partial(inference_clip_mask, v_clip=clip,
m_clip=clip_mask, ppaint=ppaint, batch_size=length,
use_half=use_half, sc_thresh=sc_thresh))
else:
base_c = clip_crop(base, mask_region)
clip_mask_c = clip_crop(clip_mask, mask_region)
clip_c = clip_crop(clip, mask_region)
v_cropped = base_c.std.ModifyFrame(clips=[clip_c, base_c, clip_mask_c], selector=partial(inference_clip_mask,
v_clip=clip_c, m_clip=clip_mask_c, ppaint=ppaint, batch_size=length, use_half=use_half, sc_thresh=sc_thresh))
v_cropped = base_c.std.ModifyFrame(clips=[clip_c, base_c, clip_mask_c],
selector=partial(inference_clip_mask,
v_clip=clip_c, m_clip=clip_mask_c, ppaint=ppaint,
batch_size=length, use_half=use_half,
sc_thresh=sc_thresh))
clip_new = mask_overlay(clip, v_cropped, x=mask_region[2], y=mask_region[3])

return clip_new


def propainter_outpaint(
clip: vs.VideoNode,
length: int = 50,
outpaint_size: tuple[int, int] = None,
mask_dilation: int = 0,
neighbor_length: int = 10,
ref_stride: int = 10,
raft_iter: int = 20,
weights_dir: str = model_dir,
enable_fp16: bool = True,
device_index: int = 0,
inference_mode: bool = False
) -> vs.VideoNode:
"""ProPainter: Improving Propagation and Transformer for Video Outpainting
:param clip: Clip to process. Only RGB24 "full range" format is supported.
:param length: Sequence length that the model processes (min. 10 frames). High values
increase the inference speed but also increase the memory usage. Default: 50
:param outpaint_size: Size of the extrapolated frames, format: (width, height). Default: None
:param mask_dilation: Mask dilation for video and flow masking. Default: 0
:param neighbor_length: Length of local neighboring frames. Low values decrease the memory usage.
High values can help improve the quality on fast-moving sequences.
Default: 10
:param ref_stride: Stride of global reference frames. High values reduce the memory usage
and increase the inference speed, but can affect the inference quality. Default: 10
:param raft_iter: Iterations for RAFT inference. Low values increase the inference
speed but can affect the output quality. Default: 20
:param weights_dir: Path of the folder containing the model weights.
:param enable_fp16: If True, use fp16 (half precision) during inference (recommended for
RTX30 or above). Default: True
:param device_index: Device ordinal of the GPU (if -1, CPU mode is enabled). Default: 0
:param inference_mode: Enable/Disable torch inference mode. Default: False
"""
if not isinstance(clip, vs.VideoNode):
raise vs.Error("propainter: this is not a clip")

if clip.format.id != vs.RGB24:
raise vs.Error("propainter: only RGB24 format is supported")

if outpaint_size is None:
raise vs.Error("propainter: please provide the outpainting size (width, height)")

if outpaint_size[0] < clip.width or outpaint_size[1] < clip.height:
raise vs.Error("propainter: outpainting size cannot be smaller than the clip size")

if device_index != -1 and not torch.cuda.is_available():
raise vs.Error("propainter: CUDA is not available")

if length < 10:
raise vs.Error("propainter: length must be at least 10")

disable_warnings()

if device_index == -1:
device = torch.device("cpu")
use_half = False
else:
device = torch.device("cuda", device_index)
use_half = enable_fp16

# enable torch inference mode
# https://pytorch.org/docs/stable/generated/torch.autograd.grad_mode.inference_mode.html
if inference_mode:
torch.backends.cudnn.benchmark = True
torch.inference_mode()

# ----------------------------------------- INFERENCE -------------------------------------------------------------

cache = {}

def inference_extrapolated_frames(n: int, f: list[vs.VideoFrame], v_clip: vs.VideoFrame = None,
ppaint: ModelProPainterOut = None,
batch_size: int = 25, use_half: bool = False) -> vs.VideoFrame:
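# ModifyFrame callback: on a cache miss it collects up to `batch_size` frames starting
# at frame n, runs ProPainter outpainting on the whole batch, stores every extrapolated
# frame in `cache`, and then returns frame n (later frames are served from the cache).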

if str(n) not in cache:
cache.clear()
# vs.core.log_message(2, "Init Cache at frame_n = " + str(n))

frames = [frame_to_image(f[0])]

for i in range(1, batch_size):

if n + i >= v_clip.num_frames:
break

frame_i = v_clip.get_frame(n + i)

frames.append(frame_to_image(frame_i))

output = ppaint.get_extrapolated_frames(video_frames=frames, batch_size=batch_size, use_half=use_half)

for i in range(len(output)):
cache[str(n + i)] = output[i]

return np_array_to_frame(cache[str(n)], f[1].copy())

# ----------------------------------------- ModifyFrame -----------------------------------------------------------

ppaint = ModelProPainterOut(device, weights_dir, mask_dilation, outpaint_size, neighbor_length,
ref_stride, raft_iter, (clip.width, clip.height))

base = clip.std.BlankClip(width=outpaint_size[0], height=outpaint_size[1], keep=True)

clip_new = base.std.ModifyFrame(clips=[clip, base],
selector=partial(inference_extrapolated_frames,
v_clip=clip,
ppaint=ppaint,
batch_size=length,
use_half=use_half))
return clip_new