CodeCamp2023-554 #2402

Open · wants to merge 3 commits into main
@@ -0,0 +1,12 @@
_base_ = ['./super-resolution_dynamic.py', '../../_base_/backends/tensorrt.py']
codebase_config = dict(type='mmagic', task='VideoSuperResolution')
backend_config = dict(
    common_config=dict(max_workspace_size=1 << 30),
    model_inputs=[
        dict(
            input_shapes=dict(
                input=dict(
                    min_shape=[1, 3, 32, 32],
                    opt_shape=[1, 3, 256, 256],
                    max_shape=[1, 3, 512, 512])))
    ])
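For context, the min/opt/max shapes above become the TensorRT optimization profile for the single 'input' binding, so each low-resolution frame fed to the engine must fall between 32x32 and 512x512. As a minimal sketch of how the resolved config can be inspected before conversion (assuming the file is saved inside the mmdeploy configs tree so its _base_ paths resolve; the local name deploy_cfg.py is hypothetical):

from mmengine import Config

deploy_cfg = Config.fromfile('deploy_cfg.py')  # hypothetical local path
# Which codebase/task processor mmdeploy should build for this deployment;
# expect type='mmagic', task='VideoSuperResolution' after merging the bases.
print(deploy_cfg.codebase_config)
# The dynamic-shape profile handed to TensorRT for the 'input' binding.
shapes = deploy_cfg.backend_config.model_inputs[0]['input_shapes']['input']
print(shapes['min_shape'], shapes['opt_shape'], shapes['max_shape'])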
3 changes: 1 addition & 2 deletions mmdeploy/backend/sdk/export_info.py
@@ -203,7 +203,7 @@ def get_preprocess(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config,
transform['to_float'] = False
transform['mean'] = [0, 0, 0]
transform['std'] = [1, 1, 1]
-if transforms[0]['type'] != 'Lift':
+if transforms[0]['type'] not in ['Lift', 'GenerateSegmentIndices']:
assert transforms[0]['type'] == 'LoadImageFromFile', \
'The first item type of pipeline should be LoadImageFromFile'
return dict(
@@ -352,7 +352,6 @@ def export2SDK(deploy_cfg: Union[str, mmengine.Config],
deploy_info = get_deploy(deploy_cfg, model_cfg, work_dir, device)
pipeline_info = get_pipeline(deploy_cfg, model_cfg, work_dir, device)
detail_info = get_detail(deploy_cfg, model_cfg, pth=pth)
-
mmengine.dump(
deploy_info,
'{}/deploy.json'.format(work_dir),
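This relaxation matters because MMagic video restoration configs typically start their test pipeline with a frame-index generator rather than a single-image loader, so requiring LoadImageFromFile as the first transform would reject every video super-resolution model during SDK export. A rough illustration of the pipeline head the check now tolerates (the transform arguments are indicative placeholders, not copied from a specific MMagic config):

# Indicative head of a video super-resolution test pipeline in MMagic-style
# configs; exact arguments differ per model, treat them as placeholders.
test_pipeline = [
    dict(type='GenerateSegmentIndices', interval_list=[1]),
    dict(type='LoadImageFromFile', key='img'),
    dict(type='PackInputs'),
]
# With the change above, export2SDK's pipeline check accepts this head instead
# of asserting that the first transform is LoadImageFromFile.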
4 changes: 2 additions & 2 deletions mmdeploy/codebase/mmagic/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from .deploy import MMEditing, SuperResolution
+from .deploy import MMEditing, SuperResolution, VideoSuperResolution

-__all__ = ['MMEditing', 'SuperResolution']
+__all__ = ['MMEditing', 'SuperResolution', 'VideoSuperResolution']
3 changes: 2 additions & 1 deletion mmdeploy/codebase/mmagic/deploy/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.codebase.mmagic.deploy.mmediting import MMEditing
from mmdeploy.codebase.mmagic.deploy.super_resolution import SuperResolution
+from mmdeploy.codebase.mmagic.deploy.video_super_resolution import VideoSuperResolution

-__all__ = ['MMEditing', 'SuperResolution']
+__all__ = ['MMEditing', 'SuperResolution', 'VideoSuperResolution']
335 changes: 335 additions & 0 deletions mmdeploy/codebase/mmagic/deploy/video_super_resolution.py
@@ -0,0 +1,335 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import mmengine
import numpy as np
import torch
from mmengine import Config
from mmengine.dataset import pseudo_collate
from mmengine.model import BaseDataPreprocessor
import os.path as osp
import mmcv
from typing import List
from mmengine.dataset import Compose

from mmdeploy.codebase.base import BaseTask
from mmdeploy.codebase.mmagic.deploy.mmediting import MMAGIC_TASK
from mmdeploy.utils import Task, get_input_shape, get_root_logger
from mmagic.apis.inferencers.base_mmagic_inferencer import (
    InputsType,
    PredType,
    ResType,
)
from mmagic.apis.inferencers.inference_functions import VIDEO_EXTENSIONS, pad_sequence
import glob
import os
from mmagic.utils import tensor2img
import cv2
from mmengine.utils import ProgressBar
from mmengine.logging import MMLogger

def process_model_config(model_cfg: mmengine.Config,
                         imgs: Union[Sequence[str], Sequence[np.ndarray]],
                         input_shape: Optional[Sequence[int]] = None):
    """Process the model config.

    Args:
        model_cfg (mmengine.Config): The model config.
        imgs (Sequence[str] | Sequence[np.ndarray]): Input image(s); accepted
            data types are List[str] and List[np.ndarray].
        input_shape (list[int]): A list of two integers in (width, height)
            format specifying input shape. Default: None.

    Returns:
        mmengine.Config: The model config after processing.
    """
    config = deepcopy(model_cfg)
    if not hasattr(config, 'test_pipeline'):
        config.__setattr__('test_pipeline', config.val_pipeline)
    keys_to_remove = ['gt', 'gt_path']
    # MMagic doesn't support LoadImageFromWebcam.
    # Remove "LoadImageFromFile" and related metakeys.
    load_from_file = isinstance(imgs[0], str)
    is_static_cfg = input_shape is not None
    if not load_from_file:
        config.test_pipeline.pop(0)
        keys_to_remove.append('lq_path')

    # Fix the input shape by 'Resize'
    if is_static_cfg:
        resize = {
            'type': 'Resize',
            'scale': (input_shape[0], input_shape[1]),
            'keys': ['img']
        }
        config.test_pipeline.insert(1, resize)
    for key in keys_to_remove:
        for pipeline in list(config.test_pipeline):
            if 'key' in pipeline and key == pipeline['key']:
                config.test_pipeline.remove(pipeline)
            if 'keys' in pipeline:
                while key in pipeline['keys']:
                    pipeline['keys'].remove(key)
                if len(pipeline['keys']) == 0:
                    config.test_pipeline.remove(pipeline)
            if 'meta_keys' in pipeline:
                while key in pipeline['meta_keys']:
                    pipeline['meta_keys'].remove(key)
    return config

@MMAGIC_TASK.register_module(Task.VIDEO_SUPER_RESOLUTION.value)
class VideoSuperResolution(BaseTask):
    """BaseTask class of video super resolution task.

    Args:
        model_cfg (mmengine.Config): Model config file.
        deploy_cfg (mmengine.Config): Deployment config file.
        device (str): A string specifying device type.
    """

    extra_parameters = dict(
        start_idx=0, filename_tmpl="{:08d}.png", window_size=0, max_seq_len=None
    )

    def __init__(
        self, model_cfg: mmengine.Config, deploy_cfg: mmengine.Config, device: str
    ):
        super(VideoSuperResolution, self).__init__(model_cfg, deploy_cfg, device)

    def build_backend_model(self,
                            model_files: Sequence[str] = None,
                            **kwargs) -> torch.nn.Module:
        """Initialize backend model.

        Args:
            model_files (Sequence[str]): Input model files. Default is None.

        Returns:
            nn.Module: An initialized backend model.
        """
        from .video_super_resolution_model import build_super_resolution_model
        data_preprocessor = deepcopy(
            self.model_cfg.model.get('data_preprocessor', {}))
        data_preprocessor.setdefault('type', 'mmagic.EditDataPreprocessor')
        model = build_super_resolution_model(
            model_files,
            self.model_cfg,
            self.deploy_cfg,
            device=self.device,
            data_preprocessor=data_preprocessor,
            **kwargs)
        return model

    def preprocess(self, video: InputsType) -> Dict:
        """Process the inputs into a model-feedable format.

        Args:
            video (InputsType): Video to be restored by models.

        Returns:
            results (InputsType): Results of preprocess.
        """
        # build the data pipeline
        if self.model_cfg.get("demo_pipeline", None):
            test_pipeline = self.model_cfg.demo_pipeline
        elif self.model_cfg.get("test_pipeline", None):
            test_pipeline = self.model_cfg.test_pipeline
        else:
            test_pipeline = self.model_cfg.val_pipeline

        # check if the input is a video
        file_extension = osp.splitext(video)[1]
        if file_extension in VIDEO_EXTENSIONS:
            video_reader = mmcv.VideoReader(video)
            # load the images
            data = dict(img=[], img_path=None, key=video)
            for frame in video_reader:
                data["img"].append(np.flip(frame, axis=2))

            # remove the data loading pipeline
            tmp_pipeline = []
            for pipeline in test_pipeline:
                if pipeline["type"] not in [
                    "GenerateSegmentIndices",
                    "LoadImageFromFile",
                ]:
                    tmp_pipeline.append(pipeline)
            test_pipeline = tmp_pipeline
        else:
            # the first element in the pipeline must be
            # 'GenerateSegmentIndices'
            if test_pipeline[0]["type"] != "GenerateSegmentIndices":
                raise TypeError(
                    "The first element in the pipeline must be "
                    f'"GenerateSegmentIndices", but got '
                    f'"{test_pipeline[0]["type"]}".'
                )

            # specify start_idx and filename_tmpl
            test_pipeline[0]["start_idx"] = self.extra_parameters["start_idx"]
            test_pipeline[0]["filename_tmpl"] = self.extra_parameters["filename_tmpl"]

            # prepare data
            sequence_length = len(glob.glob(osp.join(video, "*")))
            lq_folder = osp.dirname(video)
            key = osp.basename(video)
            data = dict(
                img_path=lq_folder, gt_path="", key=key, sequence_length=sequence_length
            )

        # compose the pipeline
        test_pipeline = Compose(test_pipeline)
        data = test_pipeline(data)
        results = data["inputs"].unsqueeze(0) / 255.0  # in cpu
        data["inputs"] = results
        return data

    def create_input(
        self,
        video: InputsType,
        input_shape: Sequence[int] = None,
        data_preprocessor: Optional[BaseDataPreprocessor] = None,
    ) -> Tuple[Dict, torch.Tensor]:
        """Create input for editing processor.

        Args:
            video (str): Path to the input video or frame directory.
            input_shape (Sequence[int] | None): A list of two integers in
                (width, height) format specifying input shape. Defaults to `None`.
            data_preprocessor (BaseDataPreprocessor): The data preprocessor
                of the model. Defaults to `None`.

        Returns:
            tuple: (data, img), meta information for the input video and input.
        """
        data = self.preprocess(video)
        return data, BaseTask.get_tensor_from_input(data)

    def visualize(self, preds: PredType,
                  result_out_dir: str = "") -> List[np.ndarray]:
        """Visualize result of a model. mmagic does not have a visualizer, so
        the results are written to disk directly here.

        Args:
            preds (PredType): Restored frames predicted by the model, as a
                tensor of shape (n, t, c, h, w).
            result_out_dir (str): Output path. If it ends with a video
                extension, the frames are written into a single video file;
                otherwise they are saved as numbered images under this
                directory. Defaults to an empty string.
        """

        file_extension = os.path.splitext(result_out_dir)[1]
        mmengine.utils.mkdir_or_exist(osp.dirname(result_out_dir))
        prog_bar = ProgressBar(preds.size(1))
        if file_extension in VIDEO_EXTENSIONS:  # save as video
            h, w = preds.shape[-2:]
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            video_writer = cv2.VideoWriter(result_out_dir, fourcc, 25, (w, h))
            for i in range(0, preds.size(1)):
                img = tensor2img(preds[:, i, :, :, :])
                video_writer.write(img.astype(np.uint8))
                prog_bar.update()
            cv2.destroyAllWindows()
            video_writer.release()
        else:
            for i in range(
                self.extra_parameters["start_idx"],
                self.extra_parameters["start_idx"] + preds.size(1),
            ):
                output_i = preds[:, i - self.extra_parameters["start_idx"], :, :, :]
                output_i = tensor2img(output_i)
                filename_tmpl = self.extra_parameters["filename_tmpl"]
                save_path_i = f"{result_out_dir}/{filename_tmpl.format(i)}"
                mmcv.imwrite(output_i, save_path_i)
                prog_bar.update()

        logger: MMLogger = MMLogger.get_current_instance()
        logger.info(f"Output video is saved at {result_out_dir}.")
        return []

    @staticmethod
    def get_partition_cfg(partition_type: str, **kwargs) -> Dict:
        """Get a certain partition config for mmagic.

        Args:
            partition_type (str): A string specifying partition type.

        Returns:
            dict: A dictionary of partition config.
        """
        raise NotImplementedError

    @staticmethod
    def get_tensor_from_input(input_data: Dict[str, Any]) -> torch.Tensor:
        """Get input tensor from input data.

        Args:
            input_data (dict): Input data containing meta info
                and image tensor.
        Returns:
            torch.Tensor: An image in `Tensor`.
        """
        return input_data['img']

    def get_preprocess(self, *args, **kwargs) -> Dict:
        """Get the preprocess information for SDK.

        Return:
            dict: Composed of the preprocess information.
        """
        input_shape = get_input_shape(self.deploy_cfg)
        model_cfg = process_model_config(self.model_cfg, [''], input_shape)
        meta_keys = [
            'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
            'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
            'valid_ratio'
        ]
        preprocess = model_cfg.test_pipeline

        preprocess.insert(1, model_cfg.model.data_preprocessor)
        preprocess.insert(2, dict(type='ImageToTensor', keys=['img']))
        transforms = preprocess
        for i, transform in enumerate(transforms):
            if 'keys' in transform and transform['keys'] == ['lq']:
                transform['keys'] = ['img']
            if 'key' in transform and transform['key'] == 'lq':
                transform['key'] = 'img'
            if transform['type'] == 'DataPreprocessor':
                transform['type'] = 'Normalize'
                transform['to_rgb'] = transform.get('to_rgb', False)
            if transform['type'] == 'PackInputs':
                meta_keys += transform[
                    'meta_keys'] if 'meta_keys' in transform else []
                transform['meta_keys'] = list(set(meta_keys))
                transform['keys'] = ['img']
                transforms[i]['type'] = 'Collect'
        return transforms

    def get_postprocess(self, *args, **kwargs) -> Dict:
        """Get the postprocess information for SDK.

        Return:
            dict: Postprocess config for super resolution.
        """
        from mmdeploy.utils import get_task_type
        from mmdeploy.utils.constants import SDK_TASK_MAP as task_map
        task = get_task_type(self.deploy_cfg)
        component = task_map[task]['component']
        post_processor = {'type': component}
        return post_processor

    def get_model_name(self, *args, **kwargs) -> str:
        """Get the model name.

        Return:
            str: the name of the model.
        """
        assert 'generator' in self.model_cfg.model, \
            'generator not in model config'
        assert 'type' in self.model_cfg.model.generator, \
            'generator contains no type'
        name = self.model_cfg.model.generator.type.lower()
        return name
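For reference, a rough end-to-end usage sketch of the new task class, assuming a MMagic video restorer config (for example BasicVSR), the TensorRT deploy config shown earlier, and an already converted engine file. Every path below is hypothetical, and the structure of the returned predictions depends on the backend wrapper built by build_super_resolution_model:

import torch
from mmengine import Config
from mmdeploy.apis.utils import build_task_processor

deploy_cfg = Config.fromfile('video-super-resolution_tensorrt_dynamic.py')  # hypothetical
model_cfg = Config.fromfile('basicvsr_2xb4_reds4.py')  # hypothetical

task_processor = build_task_processor(model_cfg, deploy_cfg, device='cuda:0')
backend_model = task_processor.build_backend_model(['end2end.engine'])  # hypothetical file

# create_input() routes through preprocess() above and accepts either a video
# file or a directory of extracted frames.
data, _ = task_processor.create_input('input_clip.mp4')  # hypothetical input
with torch.no_grad():
    preds = backend_model.test_step(data)

Whether preds can be passed straight to visualize() depends on how the backend wrapper packages its outputs, so that last step is left out of the sketch.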