Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xinntao/Real-ESRGAN#584 - custom backend device (auto, cuda, m1, cpu) #1

Merged
merged 3 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,7 @@ dmypy.json

# Pyre type checker
.pyre/

# VSCode
.vscode/
.vscode/*
21 changes: 21 additions & 0 deletions inference_realesrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

from torch.cuda import is_available as cudaIsAvailable
from torch.backends.mps import is_available as mpsIsAvailable

def main():
"""Inference demo for Real-ESRGAN.
Expand Down Expand Up @@ -52,6 +54,8 @@ def main():
parser.add_argument(
'-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu')

parser.add_argument('--backend_type', type=str, default='auto', choices=['auto', 'cuda', 'cpu', 'mps'], help='backend type. Options: auto(cuda-cpu) | cuda | cpu | mps')

args = parser.parse_args()

# determine models according to model names
Expand Down Expand Up @@ -103,6 +107,21 @@ def main():
model_path = [model_path, wdn_model_path]
dni_weight = [args.denoise_strength, 1 - args.denoise_strength]

# deternime backend type (cpu, cuda, mps)
if args.backend_type == 'auto':
if cudaIsAvailable():
backend_type = 'cuda'
elif mpsIsAvailable():
backend_type = 'mps'
else:
backend_type = 'cpu'
elif args.backend_type == 'cuda' and cudaIsAvailable():
backend_type = 'cuda'
elif args.backend_type == 'mps' and mpsIsAvailable():
backend_type = 'mps'
else:
backend_type = 'cpu'

# restorer
upsampler = RealESRGANer(
scale=netscale,
Expand All @@ -113,6 +132,7 @@ def main():
tile_pad=args.tile_pad,
pre_pad=args.pre_pad,
half=not args.fp32,
device=backend_type,
gpu_id=args.gpu_id)

if args.face_enhance: # Use GFPGAN for face enhancement
Expand All @@ -122,6 +142,7 @@ def main():
upscale=args.outscale,
arch='clean',
channel_multiplier=2,
device='cpu', # <--- MPS is not supported yet, crash pas runtime. TODO: FIX THIS
bg_upsampler=upsampler)
os.makedirs(args.output, exist_ok=True)

Expand Down
2 changes: 1 addition & 1 deletion realesrgan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def pre_process(self, img):
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible
"""
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
self.img = img.unsqueeze(0).to(self.device)
self.img = img.unsqueeze(0).contiguous().to(self.device)
if self.half:
self.img = self.img.half()

Expand Down
Loading