diff --git a/VERSION b/VERSION index c5807f9..e396b40 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.3.1 +0.2.0.0 diff --git a/facexlib/assessment/__init__.py b/facexlib/assessment/__init__.py new file mode 100644 index 0000000..9e59af0 --- /dev/null +++ b/facexlib/assessment/__init__.py @@ -0,0 +1,19 @@ +import torch + +from facexlib.utils import load_file_from_url +from .hyperiqa_net import HyperIQA + + +def init_assessment_model(model_name, half=False, device='cuda'): + if model_name == 'hypernet': + model = HyperIQA(16, 112, 224, 112, 56, 28, 14, 7) + model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/assessment_hyperIQA.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + # load the pre-trained hypernet model + hypernet_model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None) + model.hypernet.load_state_dict((torch.load(hypernet_model_path, map_location=lambda storage, loc: storage))) + model = model.eval() + model = model.to(device) + return model diff --git a/facexlib/assessment/hyperiqa_net.py b/facexlib/assessment/hyperiqa_net.py new file mode 100644 index 0000000..216fbac --- /dev/null +++ b/facexlib/assessment/hyperiqa_net.py @@ -0,0 +1,298 @@ +import torch as torch +import torch.nn as nn +from torch.nn import functional as F + + +class HyperIQA(nn.Module): + """ + Combine the hypernet and target network within a network. + """ + + def __init__(self, *args): + super(HyperIQA, self).__init__() + self.hypernet = HyperNet(*args) + + def forward(self, img): + net_params = self.hypernet(img) + # build the target network + target_net = TargetNet(net_params) + for param in target_net.parameters(): + param.requires_grad = False + # predict the face quality + pred = target_net(net_params['target_in_vec']) + return pred + + +class HyperNet(nn.Module): + """ + Hyper network for learning perceptual rules. + Args: + lda_out_channels: local distortion aware module output size. + hyper_in_channels: input feature channels for hyper network. + target_in_size: input vector size for target network. + target_fc(i)_size: fully connection layer size of target network. + feature_size: input feature map width/height for hyper network. + Note: + For size match, input args must satisfy: 'target_fc(i)_size * target_fc(i+1)_size' is divisible by 'feature_size ^ 2'. 
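As a reviewer aid, a minimal usage sketch of the new assessment entry point; it condenses inference/inference_hyperiqa.py from later in this diff, and the test image path and CUDA device are assumptions:

import cv2
import torch
import torchvision
from PIL import Image
from facexlib.assessment import init_assessment_model

# build the hyperIQA model and the transform used by the original hyperIQA authors
assess_net = init_assessment_model('hypernet', half=False, device='cuda')
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((512, 384)),
    torchvision.transforms.RandomCrop(size=224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
face = cv2.cvtColor(cv2.imread('assets/test2.jpg'), cv2.COLOR_BGR2RGB)  # assumed test image
with torch.no_grad():
    score = assess_net(transform(Image.fromarray(face)).unsqueeze(0).cuda())
print(f'quality score (higher is better, roughly 0-100): {score.item():.2f}')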
# noqa E501 + """ + + def __init__(self, lda_out_channels, hyper_in_channels, target_in_size, target_fc1_size, target_fc2_size, + target_fc3_size, target_fc4_size, feature_size): + super(HyperNet, self).__init__() + + self.hyperInChn = hyper_in_channels + self.target_in_size = target_in_size + self.f1 = target_fc1_size + self.f2 = target_fc2_size + self.f3 = target_fc3_size + self.f4 = target_fc4_size + self.feature_size = feature_size + + self.res = resnet50_backbone(lda_out_channels, target_in_size) + + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + + # Conv layers for resnet output features + self.conv1 = nn.Sequential( + nn.Conv2d(2048, 1024, 1, padding=(0, 0)), nn.ReLU(inplace=True), nn.Conv2d(1024, 512, 1, padding=(0, 0)), + nn.ReLU(inplace=True), nn.Conv2d(512, self.hyperInChn, 1, padding=(0, 0)), nn.ReLU(inplace=True)) + + # Hyper network part, conv for generating target fc weights, fc for generating target fc biases + self.fc1w_conv = nn.Conv2d( + self.hyperInChn, int(self.target_in_size * self.f1 / feature_size**2), 3, padding=(1, 1)) + self.fc1b_fc = nn.Linear(self.hyperInChn, self.f1) + + self.fc2w_conv = nn.Conv2d(self.hyperInChn, int(self.f1 * self.f2 / feature_size**2), 3, padding=(1, 1)) + self.fc2b_fc = nn.Linear(self.hyperInChn, self.f2) + + self.fc3w_conv = nn.Conv2d(self.hyperInChn, int(self.f2 * self.f3 / feature_size**2), 3, padding=(1, 1)) + self.fc3b_fc = nn.Linear(self.hyperInChn, self.f3) + + self.fc4w_conv = nn.Conv2d(self.hyperInChn, int(self.f3 * self.f4 / feature_size**2), 3, padding=(1, 1)) + self.fc4b_fc = nn.Linear(self.hyperInChn, self.f4) + + self.fc5w_fc = nn.Linear(self.hyperInChn, self.f4) + self.fc5b_fc = nn.Linear(self.hyperInChn, 1) + + def forward(self, img): + feature_size = self.feature_size + + res_out = self.res(img) + + # input vector for target net + target_in_vec = res_out['target_in_vec'].view(-1, self.target_in_size, 1, 1) + + # input features for hyper net + hyper_in_feat = self.conv1(res_out['hyper_in_feat']).view(-1, self.hyperInChn, feature_size, feature_size) + + # generating target net weights & biases + target_fc1w = self.fc1w_conv(hyper_in_feat).view(-1, self.f1, self.target_in_size, 1, 1) + target_fc1b = self.fc1b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f1) + + target_fc2w = self.fc2w_conv(hyper_in_feat).view(-1, self.f2, self.f1, 1, 1) + target_fc2b = self.fc2b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f2) + + target_fc3w = self.fc3w_conv(hyper_in_feat).view(-1, self.f3, self.f2, 1, 1) + target_fc3b = self.fc3b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f3) + + target_fc4w = self.fc4w_conv(hyper_in_feat).view(-1, self.f4, self.f3, 1, 1) + target_fc4b = self.fc4b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f4) + + target_fc5w = self.fc5w_fc(self.pool(hyper_in_feat).squeeze()).view(-1, 1, self.f4, 1, 1) + target_fc5b = self.fc5b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, 1) + + out = {} + out['target_in_vec'] = target_in_vec + out['target_fc1w'] = target_fc1w + out['target_fc1b'] = target_fc1b + out['target_fc2w'] = target_fc2w + out['target_fc2b'] = target_fc2b + out['target_fc3w'] = target_fc3w + out['target_fc3b'] = target_fc3b + out['target_fc4w'] = target_fc4w + out['target_fc4b'] = target_fc4b + out['target_fc5w'] = target_fc5w + out['target_fc5b'] = target_fc5b + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, 
kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNetBackbone(nn.Module): + + def __init__(self, lda_out_channels, in_chn, block, layers, num_classes=1000): + super(ResNetBackbone, self).__init__() + self.inplanes = 64 + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + # local distortion aware module + self.lda1_pool = nn.Sequential( + nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0, bias=False), + nn.AvgPool2d(7, stride=7), + ) + self.lda1_fc = nn.Linear(16 * 64, lda_out_channels) + + self.lda2_pool = nn.Sequential( + nn.Conv2d(512, 32, kernel_size=1, stride=1, padding=0, bias=False), + nn.AvgPool2d(7, stride=7), + ) + self.lda2_fc = nn.Linear(32 * 16, lda_out_channels) + + self.lda3_pool = nn.Sequential( + nn.Conv2d(1024, 64, kernel_size=1, stride=1, padding=0, bias=False), + nn.AvgPool2d(7, stride=7), + ) + self.lda3_fc = nn.Linear(64 * 4, lda_out_channels) + + self.lda4_pool = nn.AvgPool2d(7, stride=7) + self.lda4_fc = nn.Linear(2048, in_chn - lda_out_channels * 3) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + + # the same effect as lda operation in the paper, but save much more memory + lda_1 = self.lda1_fc(self.lda1_pool(x).view(x.size(0), -1)) + x = self.layer2(x) + lda_2 = self.lda2_fc(self.lda2_pool(x).view(x.size(0), -1)) + x = self.layer3(x) + lda_3 = self.lda3_fc(self.lda3_pool(x).view(x.size(0), -1)) + x = self.layer4(x) + lda_4 = self.lda4_fc(self.lda4_pool(x).view(x.size(0), -1)) + + vec = torch.cat((lda_1, lda_2, lda_3, lda_4), 1) + + out = {} + out['hyper_in_feat'] = x + out['target_in_vec'] = vec + + return out + + +def resnet50_backbone(lda_out_channels, in_chn, **kwargs): + """Constructs a ResNet-50 model_hyper.""" + model = ResNetBackbone(lda_out_channels, in_chn, Bottleneck, [3, 4, 6, 3], **kwargs) + return model + + +class 
TargetNet(nn.Module): + """ + Target network for quality prediction. + """ + + def __init__(self, paras): + super(TargetNet, self).__init__() + self.l1 = nn.Sequential( + TargetFC(paras['target_fc1w'], paras['target_fc1b']), + nn.Sigmoid(), + ) + self.l2 = nn.Sequential( + TargetFC(paras['target_fc2w'], paras['target_fc2b']), + nn.Sigmoid(), + ) + + self.l3 = nn.Sequential( + TargetFC(paras['target_fc3w'], paras['target_fc3b']), + nn.Sigmoid(), + ) + + self.l4 = nn.Sequential( + TargetFC(paras['target_fc4w'], paras['target_fc4b']), + nn.Sigmoid(), + TargetFC(paras['target_fc5w'], paras['target_fc5b']), + ) + + def forward(self, x): + q = self.l1(x) + # q = F.dropout(q) + q = self.l2(q) + q = self.l3(q) + q = self.l4(q).squeeze() + return q + + +class TargetFC(nn.Module): + """ + Fully connection operations for target net + Note: + Weights & biases are different for different images in a batch, + thus here we use group convolution for calculating images in a batch with individual weights & biases. + """ + + def __init__(self, weight, bias): + super(TargetFC, self).__init__() + self.weight = weight + self.bias = bias + + def forward(self, input_): + + input_re = input_.view(-1, input_.shape[0] * input_.shape[1], input_.shape[2], input_.shape[3]) + weight_re = self.weight.view(self.weight.shape[0] * self.weight.shape[1], self.weight.shape[2], + self.weight.shape[3], self.weight.shape[4]) + bias_re = self.bias.view(self.bias.shape[0] * self.bias.shape[1]) + out = F.conv2d(input=input_re, weight=weight_re, bias=bias_re, groups=self.weight.shape[0]) + + return out.view(input_.shape[0], self.weight.shape[1], input_.shape[2], input_.shape[3]) diff --git a/facexlib/headpose/__init__.py b/facexlib/headpose/__init__.py new file mode 100644 index 0000000..e02a334 --- /dev/null +++ b/facexlib/headpose/__init__.py @@ -0,0 +1,19 @@ +import torch + +from facexlib.utils import load_file_from_url +from .hopenet_arch import HopeNet + + +def init_headpose_model(model_name, half=False, device='cuda'): + if model_name == 'hopenet': + model = HopeNet('resnet', [3, 4, 6, 3], 66) + model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/headpose_hopenet.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage)['params'] + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + return model diff --git a/facexlib/headpose/hopenet_arch.py b/facexlib/headpose/hopenet_arch.py new file mode 100644 index 0000000..b3a0141 --- /dev/null +++ b/facexlib/headpose/hopenet_arch.py @@ -0,0 +1,72 @@ +import torch +import torch.nn as nn +import torchvision + + +class HopeNet(nn.Module): + # Hopenet with 3 output layers for yaw, pitch and roll + # Predicts Euler angles by binning and regression with the expected value + def __init__(self, block, layers, num_bins): + super(HopeNet, self).__init__() + if block == 'resnet': + block = torchvision.models.resnet.Bottleneck + self.inplanes = 64 + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], 
stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7) + self.fc_yaw = nn.Linear(512 * block.expansion, num_bins) + self.fc_pitch = nn.Linear(512 * block.expansion, num_bins) + self.fc_roll = nn.Linear(512 * block.expansion, num_bins) + + self.idx_tensor = torch.arange(66).float() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + return nn.Sequential(*layers) + + @staticmethod + def softmax_temperature(tensor, temperature): + result = torch.exp(tensor / temperature) + result = torch.div(result, torch.sum(result, 1).unsqueeze(1).expand_as(result)) + return result + + def bin2degree(self, predict): + predict = self.softmax_temperature(predict, 1) + return torch.sum(predict * self.idx_tensor.type_as(predict), 1) * 3 - 99 + + def forward(self, x): + x = self.relu(self.bn1(self.conv1(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + pre_yaw = self.fc_yaw(x) + pre_pitch = self.fc_pitch(x) + pre_roll = self.fc_roll(x) + + yaw = self.bin2degree(pre_yaw) + pitch = self.bin2degree(pre_pitch) + roll = self.bin2degree(pre_roll) + return yaw, pitch, roll diff --git a/facexlib/matting/__init__.py b/facexlib/matting/__init__.py new file mode 100644 index 0000000..3301590 --- /dev/null +++ b/facexlib/matting/__init__.py @@ -0,0 +1,26 @@ +import torch +from copy import deepcopy + +from facexlib.utils import load_file_from_url +from .modnet import MODNet + + +def init_matting_model(model_name='modnet', half=False, device='cuda'): + if model_name == 'modnet': + model = MODNet(backbone_pretrained=False) + model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/matting_modnet_portrait.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None) + # TODO: clean pretrained model + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + # remove unnecessary 'module.' 
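+    # (checkpoints saved from an nn.DataParallel-wrapped model carry a leading
+    # 'module.' on every state_dict key; the loop below strips that prefix so
+    # the weights load into the plain MODNet module with strict=True)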
+ for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + return model diff --git a/facexlib/matting/backbone.py b/facexlib/matting/backbone.py new file mode 100644 index 0000000..4cb295f --- /dev/null +++ b/facexlib/matting/backbone.py @@ -0,0 +1,80 @@ +import os +import torch +import torch.nn as nn + +from .mobilenetv2 import MobileNetV2 + + +class BaseBackbone(nn.Module): + """ Superclass of Replaceable Backbone Model for Semantic Estimation + """ + + def __init__(self, in_channels): + super(BaseBackbone, self).__init__() + self.in_channels = in_channels + + self.model = None + self.enc_channels = [] + + def forward(self, x): + raise NotImplementedError + + def load_pretrained_ckpt(self): + raise NotImplementedError + + +class MobileNetV2Backbone(BaseBackbone): + """ MobileNetV2 Backbone + """ + + def __init__(self, in_channels): + super(MobileNetV2Backbone, self).__init__(in_channels) + + self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None) + self.enc_channels = [16, 24, 32, 96, 1280] + + def forward(self, x): + # x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x) + x = self.model.features[0](x) + x = self.model.features[1](x) + enc2x = x + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x) + x = self.model.features[2](x) + x = self.model.features[3](x) + enc4x = x + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x) + x = self.model.features[4](x) + x = self.model.features[5](x) + x = self.model.features[6](x) + enc8x = x + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x) + x = self.model.features[7](x) + x = self.model.features[8](x) + x = self.model.features[9](x) + x = self.model.features[10](x) + x = self.model.features[11](x) + x = self.model.features[12](x) + x = self.model.features[13](x) + enc16x = x + + # x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x) + x = self.model.features[14](x) + x = self.model.features[15](x) + x = self.model.features[16](x) + x = self.model.features[17](x) + x = self.model.features[18](x) + enc32x = x + return [enc2x, enc4x, enc8x, enc16x, enc32x] + + def load_pretrained_ckpt(self): + # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch + ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt' + if not os.path.exists(ckpt_path): + print('cannot find the pretrained mobilenetv2 backbone') + exit() + + ckpt = torch.load(ckpt_path) + self.model.load_state_dict(ckpt) diff --git a/facexlib/matting/mobilenetv2.py b/facexlib/matting/mobilenetv2.py new file mode 100644 index 0000000..c649586 --- /dev/null +++ b/facexlib/matting/mobilenetv2.py @@ -0,0 +1,192 @@ +""" This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch""" + +import math +import torch +from torch import nn + +# ------------------------------------------------------------------------------ +# Useful functions +# ------------------------------------------------------------------------------ + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
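+    # e.g. with divisor=8: _make_divisible(12, 8) -> 16 and _make_divisible(22, 8) -> 24,
+    # while _make_divisible(10, 8) first rounds down to 8 (a 20% drop), so the
+    # check below bumps it back up to 16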
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def conv_bn(inp, oup, stride): + return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True)) + + +def conv_1x1_bn(inp, oup): + return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True)) + + +# ------------------------------------------------------------------------------ +# Class of Inverted Residual block +# ------------------------------------------------------------------------------ + + +class InvertedResidual(nn.Module): + + def __init__(self, inp, oup, stride, expansion, dilation=1): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = round(inp * expansion) + self.use_res_connect = self.stride == 1 and inp == oup + + if expansion == 1: + self.conv = nn.Sequential( + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + else: + self.conv = nn.Sequential( + # pw + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +# ------------------------------------------------------------------------------ +# Class of MobileNetV2 +# ------------------------------------------------------------------------------ + + +class MobileNetV2(nn.Module): + + def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000): + super(MobileNetV2, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + input_channel = 32 + last_channel = 1280 + interverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [expansion, 24, 2, 2], + [expansion, 32, 3, 2], + [expansion, 64, 4, 2], + [expansion, 96, 3, 1], + [expansion, 160, 3, 2], + [expansion, 320, 1, 1], + ] + + # building first layer + input_channel = _make_divisible(input_channel * alpha, 8) + self.last_channel = _make_divisible(last_channel * alpha, 8) if alpha > 1.0 else last_channel + self.features = [conv_bn(self.in_channels, input_channel, 2)] + + # building inverted residual blocks + for t, c, n, s in interverted_residual_setting: + output_channel = _make_divisible(int(c * alpha), 8) + for i in range(n): + if i == 0: + self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t)) + else: + self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t)) + input_channel = output_channel + + # building last several layers + self.features.append(conv_1x1_bn(input_channel, self.last_channel)) + + # make it nn.Sequential + self.features = nn.Sequential(*self.features) + + # building classifier + if self.num_classes is not None: + self.classifier = nn.Sequential( + nn.Dropout(0.2), + nn.Linear(self.last_channel, num_classes), + ) + + # Initialize weights + self._init_weights() + + def forward(self, x): + # Stage1 + x = self.features[0](x) + x = self.features[1](x) + # Stage2 + x = self.features[2](x) + x = self.features[3](x) + # Stage3 + x = 
self.features[4](x) + x = self.features[5](x) + x = self.features[6](x) + # Stage4 + x = self.features[7](x) + x = self.features[8](x) + x = self.features[9](x) + x = self.features[10](x) + x = self.features[11](x) + x = self.features[12](x) + x = self.features[13](x) + # Stage5 + x = self.features[14](x) + x = self.features[15](x) + x = self.features[16](x) + x = self.features[17](x) + x = self.features[18](x) + + # Classification + if self.num_classes is not None: + x = x.mean(dim=(2, 3)) + x = self.classifier(x) + + # Output + return x + + def _load_pretrained_model(self, pretrained_file): + pretrain_dict = torch.load(pretrained_file, map_location='cpu') + model_dict = {} + state_dict = self.state_dict() + print('[MobileNetV2] Loading pretrained model...') + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + else: + print(k, 'is ignored') + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() diff --git a/facexlib/matting/modnet.py b/facexlib/matting/modnet.py new file mode 100644 index 0000000..cd23c38 --- /dev/null +++ b/facexlib/matting/modnet.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbone import MobileNetV2Backbone + +# ------------------------------------------------------------------------------ +# MODNet Basic Modules +# ------------------------------------------------------------------------------ + + +class IBNorm(nn.Module): + """ Combine Instance Norm and Batch Norm into One Layer + """ + + def __init__(self, in_channels): + super(IBNorm, self).__init__() + in_channels = in_channels + self.bnorm_channels = int(in_channels / 2) + self.inorm_channels = in_channels - self.bnorm_channels + + self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True) + self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) + + def forward(self, x): + bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous()) + in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous()) + + return torch.cat((bn_x, in_x), 1) + + +class Conv2dIBNormRelu(nn.Module): + """ Convolution + IBNorm + ReLu + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + with_ibn=True, + with_relu=True): + super(Conv2dIBNormRelu, self).__init__() + + layers = [ + nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias) + ] + + if with_ibn: + layers.append(IBNorm(out_channels)) + if with_relu: + layers.append(nn.ReLU(inplace=True)) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class SEBlock(nn.Module): + """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf + """ + + def __init__(self, in_channels, out_channels, reduction=1): + super(SEBlock, self).__init__() + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(in_channels, int(in_channels // reduction), bias=False), nn.ReLU(inplace=True), + nn.Linear(int(in_channels // reduction), 
out_channels, bias=False), nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + w = self.pool(x).view(b, c) + w = self.fc(w).view(b, c, 1, 1) + + return x * w.expand_as(x) + + +# ------------------------------------------------------------------------------ +# MODNet Branches +# ------------------------------------------------------------------------------ + + +class LRBranch(nn.Module): + """ Low Resolution Branch of MODNet + """ + + def __init__(self, backbone): + super(LRBranch, self).__init__() + + enc_channels = backbone.enc_channels + + self.backbone = backbone + self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4) + self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2) + self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2) + self.conv_lr = Conv2dIBNormRelu( + enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False) + + def forward(self, img, inference): + enc_features = self.backbone.forward(img) + enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4] + + enc32x = self.se_block(enc32x) + lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False) + lr16x = self.conv_lr16x(lr16x) + lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False) + lr8x = self.conv_lr8x(lr8x) + + pred_semantic = None + if not inference: + lr = self.conv_lr(lr8x) + pred_semantic = torch.sigmoid(lr) + + return pred_semantic, lr8x, [enc2x, enc4x] + + +class HRBranch(nn.Module): + """ High Resolution Branch of MODNet + """ + + def __init__(self, hr_channels, enc_channels): + super(HRBranch, self).__init__() + + self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0) + self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1) + + self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0) + self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) + + self.conv_hr4x = nn.Sequential( + Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), + ) + + self.conv_hr2x = nn.Sequential( + Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), + ) + + self.conv_hr = nn.Sequential( + Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1), + Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False), + ) + + def forward(self, img, enc2x, enc4x, lr8x, inference): + img2x = F.interpolate(img, scale_factor=1 / 2, mode='bilinear', align_corners=False) + img4x = F.interpolate(img, scale_factor=1 / 4, mode='bilinear', align_corners=False) + + enc2x = self.tohr_enc2x(enc2x) + hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1)) + + enc4x = self.tohr_enc4x(enc4x) + hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1)) + + lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) + hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1)) + + hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', 
align_corners=False) + hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1)) + + pred_detail = None + if not inference: + hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False) + hr = self.conv_hr(torch.cat((hr, img), dim=1)) + pred_detail = torch.sigmoid(hr) + + return pred_detail, hr2x + + +class FusionBranch(nn.Module): + """ Fusion Branch of MODNet + """ + + def __init__(self, hr_channels, enc_channels): + super(FusionBranch, self).__init__() + self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2) + + self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1) + self.conv_f = nn.Sequential( + Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), + Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False), + ) + + def forward(self, img, lr8x, hr2x): + lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) + lr4x = self.conv_lr4x(lr4x) + lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False) + + f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1)) + f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False) + f = self.conv_f(torch.cat((f, img), dim=1)) + pred_matte = torch.sigmoid(f) + + return pred_matte + + +# ------------------------------------------------------------------------------ +# MODNet +# ------------------------------------------------------------------------------ + + +class MODNet(nn.Module): + """ Architecture of MODNet + """ + + def __init__(self, in_channels=3, hr_channels=32, backbone_pretrained=True): + super(MODNet, self).__init__() + + self.in_channels = in_channels + self.hr_channels = hr_channels + self.backbone_pretrained = backbone_pretrained + + self.backbone = MobileNetV2Backbone(self.in_channels) + + self.lr_branch = LRBranch(self.backbone) + self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels) + self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + self._init_conv(m) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): + self._init_norm(m) + + if self.backbone_pretrained: + self.backbone.load_pretrained_ckpt() + + def forward(self, img, inference): + pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference) + pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference) + pred_matte = self.f_branch(img, lr8x, hr2x) + + return pred_semantic, pred_detail, pred_matte + + def freeze_norm(self): + norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d] + for m in self.modules(): + for n in norm_types: + if isinstance(m, n): + m.eval() + continue + + def _init_conv(self, conv): + nn.init.kaiming_uniform_(conv.weight, a=0, mode='fan_in', nonlinearity='relu') + if conv.bias is not None: + nn.init.constant_(conv.bias, 0) + + def _init_norm(self, norm): + if norm.weight is not None: + nn.init.constant_(norm.weight, 1) + nn.init.constant_(norm.bias, 0) diff --git a/facexlib/parsing/__init__.py b/facexlib/parsing/__init__.py new file mode 100644 index 0000000..12aca1f --- /dev/null +++ b/facexlib/parsing/__init__.py @@ -0,0 +1,19 @@ +import torch + +from facexlib.utils import load_file_from_url +from .bisenet import BiSeNet + + +def init_parsing_model(model_name='bisenet', half=False, device='cuda'): + if model_name == 'bisenet': + model = BiSeNet(num_class=19) + model_url = 
'https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='facexlib/weights', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + return model diff --git a/facexlib/parsing/bisenet.py b/facexlib/parsing/bisenet.py new file mode 100644 index 0000000..3898cab --- /dev/null +++ b/facexlib/parsing/bisenet.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .resnet import ResNet18 + + +class ConvBNReLU(nn.Module): + + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_chan) + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + +class BiSeNetOutput(nn.Module): + + def __init__(self, in_chan, mid_chan, num_class): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False) + + def forward(self, x): + feat = self.conv(x) + out = self.conv_out(feat) + return out, feat + + +class AttentionRefinementModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + +class ContextPath(nn.Module): + + def __init__(self): + super(ContextPath, self).__init__() + self.resnet = ResNet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + def forward(self, x): + feat8, feat16, feat32 = self.resnet(x) + h8, w8 = feat8.size()[2:] + h16, w16 = feat16.size()[2:] + h32, w32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (h32, w32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + +class FeatureFusionModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) + self.conv2 = nn.Conv2d(out_chan 
// 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + +class BiSeNet(nn.Module): + + def __init__(self, num_class): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, num_class) + self.conv_out16 = BiSeNetOutput(128, 64, num_class) + self.conv_out32 = BiSeNetOutput(128, 64, num_class) + + def forward(self, x, return_feat=False): + h, w = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature + feat_sp = feat_res8 # replace spatial path feature with res3b1 feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + out, feat = self.conv_out(feat_fuse) + out16, feat16 = self.conv_out16(feat_cp8) + out32, feat32 = self.conv_out32(feat_cp16) + + out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True) + out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True) + out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True) + + if return_feat: + feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True) + feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True) + feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True) + return out, out16, out32, feat, feat16, feat32 + else: + return out, out16, out32 diff --git a/facexlib/parsing/resnet.py b/facexlib/parsing/resnet.py new file mode 100644 index 0000000..fec8e82 --- /dev/null +++ b/facexlib/parsing/resnet.py @@ -0,0 +1,69 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum - 1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class ResNet18(nn.Module): + + def __init__(self): + super(ResNet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 
64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 diff --git a/facexlib/utils/__init__.py b/facexlib/utils/__init__.py index a91e391..706e077 100644 --- a/facexlib/utils/__init__.py +++ b/facexlib/utils/__init__.py @@ -1,6 +1,7 @@ from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back -from .misc import load_file_from_url +from .misc import img2tensor, load_file_from_url, scandir __all__ = [ - 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', 'paste_face_back' + 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', 'paste_face_back', + 'img2tensor', 'scandir' ] diff --git a/facexlib/utils/face_restoration_helper.py b/facexlib/utils/face_restoration_helper.py index 45c0ab8..22198a2 100644 --- a/facexlib/utils/face_restoration_helper.py +++ b/facexlib/utils/face_restoration_helper.py @@ -52,7 +52,8 @@ def __init__(self, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', - template_3points=True): + template_3points=True, + pad_blur=False): self.template_3points = template_3points # improve robustness self.upscale_factor = upscale_factor # the cropped face ratio based on the square face @@ -66,12 +67,15 @@ def __init__(self, # standard 5 landmarks for FFHQ faces with 512 x 512 self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935], [201.26117, 371.41043], [313.08905, 371.15118]]) - + self.face_template = self.face_template * (face_size / 512.0) if self.crop_ratio[0] > 1: self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2 if self.crop_ratio[1] > 1: self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2 self.save_ext = save_ext + self.pad_blur = pad_blur + if self.pad_blur is True: + self.template_3points = False self.all_landmarks_5 = [] self.det_faces = [] @@ -79,6 +83,7 @@ def __init__(self, self.inverse_affine_matrices = [] self.cropped_faces = [] self.restored_faces = [] + self.pad_input_imgs = [] # init face detection model device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") @@ -95,9 +100,17 @@ def read_image(self, img): else: self.input_img = img - def get_face_landmarks_5(self, only_keep_largest=False, only_center_face=False, pad_blur=False): + def get_face_landmarks_5(self, only_keep_largest=False, only_center_face=False, resize=None, blur_ratio=0.01): + if resize is None: + scale = 1 + input_img = self.input_img + else: + h, w = self.input_img.shape[0:2] + scale = min(h, w) / resize + h, w = int(h / scale), int(w / scale) + input_img = cv2.resize(self.input_img, (w, h), cv2.INTER_LANCZOS4) with torch.no_grad(): - bboxes = self.face_det.detect_faces(self.input_img, 0.97) + bboxes = self.face_det.detect_faces(input_img, 0.97) * scale for bbox in bboxes: if self.template_3points: landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)]) @@ -115,11 +128,84 @@ def get_face_landmarks_5(self, only_keep_largest=False, only_center_face=False, h, w, _ = self.input_img.shape self.det_faces, center_idx = 
get_center_face(self.det_faces, h, w) self.all_landmarks_5 = [self.all_landmarks_5[center_idx]] + + # pad blurry images + if self.pad_blur: + self.pad_input_imgs = [] + for landmarks in self.all_landmarks_5: + # get landmarks + eye_left = landmarks[0, :] + eye_right = landmarks[1, :] + eye_avg = (eye_left + eye_right) * 0.5 + mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1.5 + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + # y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = np.hypot(*x) * 2 + border = max(int(np.rint(qsize * 0.1)), 3) + + # get pad + # pad: (width_left, height_top, width_right, height_bottom) + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = [ + max(-pad[0] + border, 1), + max(-pad[1] + border, 1), + max(pad[2] - self.input_img.shape[0] + border, 1), + max(pad[3] - self.input_img.shape[1] + border, 1) + ] + + if max(pad) > 1: + # pad image + pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + # modify landmark coords + landmarks[:, 0] += pad[0] + landmarks[:, 1] += pad[1] + # blur pad images + h, w, _ = pad_img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = int(qsize * blur_ratio) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur)) + # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0) + + pad_img = pad_img.astype('float32') + pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0) + pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255] + self.pad_input_imgs.append(pad_img) + else: + self.pad_input_imgs.append(np.copy(self.input_img)) + return len(self.all_landmarks_5) def align_warp_face(self, save_cropped_path=None, border_mode='constant'): """Align and warp faces with face template. 
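For context, a minimal sketch of how the new pad_blur and resize options are meant to be used together; the helper class name FaceRestoreHelper, its leading constructor arguments, and the image path are assumptions not shown in this hunk:

import cv2
from facexlib.utils.face_restoration_helper import FaceRestoreHelper

# pad_blur=True falls back to the 5-landmark template and enables the blurred-border padding above
helper = FaceRestoreHelper(upscale_factor=1, face_size=512, pad_blur=True)
helper.read_image(cv2.imread('assets/test.jpg'))
# detection runs on a copy rescaled so that its short side equals `resize`,
# and the returned boxes/landmarks are scaled back to the original resolution
num_faces = helper.get_face_landmarks_5(only_center_face=True, resize=640, blur_ratio=0.01)
helper.align_warp_face(save_cropped_path='assets/test_cropped.png')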
""" + if self.pad_blur: + assert len(self.pad_input_imgs) == len( + self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}' for idx, landmark in enumerate(self.all_landmarks_5): # use 5 landmarks to get affine matrix affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template)[0] @@ -131,9 +217,12 @@ def align_warp_face(self, save_cropped_path=None, border_mode='constant'): border_mode = cv2.BORDER_REFLECT101 elif border_mode == 'reflect': border_mode = cv2.BORDER_REFLECT + if self.pad_blur: + input_img = self.pad_input_imgs[idx] + else: + input_img = self.input_img cropped_face = cv2.warpAffine( - self.input_img, affine_matrix, self.face_size, borderMode=border_mode, - borderValue=(135, 133, 132)) # gray + input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray self.cropped_faces.append(cropped_face) # save the cropped face if save_cropped_path is not None: @@ -194,3 +283,4 @@ def clean_all(self): self.cropped_faces = [] self.inverse_affine_matrices = [] self.det_faces = [] + self.pad_input_imgs = [] diff --git a/facexlib/utils/misc.py b/facexlib/utils/misc.py index af4d538..3d2475a 100644 --- a/facexlib/utils/misc.py +++ b/facexlib/utils/misc.py @@ -1,5 +1,6 @@ import cv2 import os +import os.path as osp import torch from torch.hub import download_url_to_file, get_dir from urllib.parse import urlparse @@ -73,3 +74,43 @@ def load_file_from_url(url, model_dir=None, progress=True, file_name=None): print(f'Downloading: "{url}" to {cached_file}\n') download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) return cached_file + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + Returns: + A generator for all the interested files with relative pathes. 
+ """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) diff --git a/facexlib/visualization/__init__.py b/facexlib/visualization/__init__.py index 24acb3e..290fee7 100644 --- a/facexlib/visualization/__init__.py +++ b/facexlib/visualization/__init__.py @@ -1,4 +1,5 @@ from .vis_alignment import visualize_alignment from .vis_detection import visualize_detection +from .vis_headpose import visualize_headpose -__all__ = ['visualize_detection', 'visualize_alignment'] +__all__ = ['visualize_detection', 'visualize_alignment', 'visualize_headpose'] diff --git a/facexlib/visualization/vis_headpose.py b/facexlib/visualization/vis_headpose.py new file mode 100644 index 0000000..5797517 --- /dev/null +++ b/facexlib/visualization/vis_headpose.py @@ -0,0 +1,91 @@ +import cv2 +import numpy as np +from math import cos, sin + + +def draw_axis(img, yaw, pitch, roll, tdx=None, tdy=None, size=100): + """draw head pose axis.""" + + pitch = pitch * np.pi / 180 + yaw = -yaw * np.pi / 180 + roll = roll * np.pi / 180 + + if tdx is None or tdy is None: + height, width = img.shape[:2] + tdx = width / 2 + tdy = height / 2 + + # X axis pointing to right, drawn in red + x1 = size * (cos(yaw) * cos(roll)) + tdx + y1 = size * (cos(pitch) * sin(roll) + cos(roll) * sin(pitch) * sin(yaw)) + tdy + # Y axis poiting downside, drawn in green + x2 = size * (-cos(yaw) * sin(roll)) + tdx + y2 = size * (cos(pitch) * cos(roll) - sin(pitch) * sin(yaw) * sin(roll)) + tdy + # Z axis, out of the screen, drawn in blue + x3 = size * (sin(yaw)) + tdx + y3 = size * (-cos(yaw) * sin(pitch)) + tdy + + cv2.line(img, (int(tdx), int(tdy)), (int(x1), int(y1)), (0, 0, 255), 3) + cv2.line(img, (int(tdx), int(tdy)), (int(x2), int(y2)), (0, 255, 0), 3) + cv2.line(img, (int(tdx), int(tdy)), (int(x3), int(y3)), (255, 0, 0), 2) + + return img + + +def draw_pose_cube(img, yaw, pitch, roll, tdx=None, tdy=None, size=150.): + """draw head pose cube. + Where (tdx, tdy) is the translation of the face. 
+ For pose we have [pitch yaw roll tdx tdy tdz scale_factor] + """ + + p = pitch * np.pi / 180 + y = -yaw * np.pi / 180 + r = roll * np.pi / 180 + if tdx is not None and tdy is not None: + face_x = tdx - 0.50 * size + face_y = tdy - 0.50 * size + else: + height, width = img.shape[:2] + face_x = width / 2 - 0.5 * size + face_y = height / 2 - 0.5 * size + + x1 = size * (cos(y) * cos(r)) + face_x + y1 = size * (cos(p) * sin(r) + cos(r) * sin(p) * sin(y)) + face_y + x2 = size * (-cos(y) * sin(r)) + face_x + y2 = size * (cos(p) * cos(r) - sin(p) * sin(y) * sin(r)) + face_y + x3 = size * (sin(y)) + face_x + y3 = size * (-cos(y) * sin(p)) + face_y + + # Draw base in red + cv2.line(img, (int(face_x), int(face_y)), (int(x1), int(y1)), (0, 0, 255), 3) + cv2.line(img, (int(face_x), int(face_y)), (int(x2), int(y2)), (0, 0, 255), 3) + cv2.line(img, (int(x2), int(y2)), (int(x2 + x1 - face_x), int(y2 + y1 - face_y)), (0, 0, 255), 3) + cv2.line(img, (int(x1), int(y1)), (int(x1 + x2 - face_x), int(y1 + y2 - face_y)), (0, 0, 255), 3) + # Draw pillars in blue + cv2.line(img, (int(face_x), int(face_y)), (int(x3), int(y3)), (255, 0, 0), 2) + cv2.line(img, (int(x1), int(y1)), (int(x1 + x3 - face_x), int(y1 + y3 - face_y)), (255, 0, 0), 2) + cv2.line(img, (int(x2), int(y2)), (int(x2 + x3 - face_x), int(y2 + y3 - face_y)), (255, 0, 0), 2) + cv2.line(img, (int(x2 + x1 - face_x), int(y2 + y1 - face_y)), + (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (255, 0, 0), 2) + # Draw top in green + cv2.line(img, (int(x3 + x1 - face_x), int(y3 + y1 - face_y)), + (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (0, 255, 0), 2) + cv2.line(img, (int(x2 + x3 - face_x), int(y2 + y3 - face_y)), + (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (0, 255, 0), 2) + cv2.line(img, (int(x3), int(y3)), (int(x3 + x1 - face_x), int(y3 + y1 - face_y)), (0, 255, 0), 2) + cv2.line(img, (int(x3), int(y3)), (int(x3 + x2 - face_x), int(y3 + y2 - face_y)), (0, 255, 0), 2) + + return img + + +def visualize_headpose(img, yaw, pitch, roll, save_path=None, to_bgr=False): + img = np.copy(img) + if to_bgr: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + show_string = (f'y {yaw[0].item():.2f}, p {pitch[0].item():.2f}, ' + f'r {roll[0].item():.2f}') + cv2.putText(img, show_string, (30, img.shape[0] - 30), fontFace=1, fontScale=1, color=(0, 0, 255), thickness=2) + draw_pose_cube(img, yaw[0], pitch[0], roll[0], size=100) + draw_axis(img, yaw[0], pitch[0], roll[0], tdx=50, tdy=50, size=100) + # save img + if save_path is not None: + cv2.imwrite(save_path, img) diff --git a/inference/inference_headpose.py b/inference/inference_headpose.py new file mode 100644 index 0000000..142a68c --- /dev/null +++ b/inference/inference_headpose.py @@ -0,0 +1,54 @@ +import argparse +import cv2 +import numpy as np +import torch +from torchvision.transforms.functional import normalize + +from facexlib.detection import init_detection_model +from facexlib.headpose import init_headpose_model +from facexlib.utils.misc import img2tensor +from facexlib.visualization import visualize_headpose + + +def main(args): + # initialize model + det_net = init_detection_model(args.detection_model_name, half=args.half) + headpose_net = init_headpose_model(args.headpose_model_name, half=args.half) + + img = cv2.imread(args.img_path) + with torch.no_grad(): + bboxes = det_net.detect_faces(img, 0.97) + # x0, y0, x1, y1, confidence_score, five points (x, y) + bbox = list(map(int, bboxes[0])) + # crop face region + thld = 10 + h, w, _ = img.shape 
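+        # enlarge the detected box by `thld` pixels on each side, clamped to the
+        # image borders, so the head pose crop keeps some context around the face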
+ top = max(bbox[1] - thld, 0) + bottom = min(bbox[3] + thld, h) + left = max(bbox[0] - thld, 0) + right = min(bbox[2] + thld, w) + + det_face = img[top:bottom, left:right, :].astype(np.float32) / 255. + + # resize + det_face = cv2.resize(det_face, (224, 224), interpolation=cv2.INTER_LINEAR) + det_face = img2tensor(np.copy(det_face), bgr2rgb=False) + + # normalize + normalize(det_face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], inplace=True) + det_face = det_face.unsqueeze(0).cuda() + + yaw, pitch, roll = headpose_net(det_face) + visualize_headpose(img, yaw, pitch, roll, args.save_path) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.') + parser.add_argument('--img_path', type=str, default='assets/test.jpg') + parser.add_argument('--save_path', type=str, default='assets/test_headpose.png') + parser.add_argument('--detection_model_name', type=str, default='retinaface_resnet50') + parser.add_argument('--headpose_model_name', type=str, default='hopenet') + parser.add_argument('--half', action='store_true') + args = parser.parse_args() + + main(args) diff --git a/inference/inference_hyperiqa.py b/inference/inference_hyperiqa.py new file mode 100644 index 0000000..c82efa2 --- /dev/null +++ b/inference/inference_hyperiqa.py @@ -0,0 +1,63 @@ +import argparse +import cv2 +import numpy as np +import os +import torch +import torchvision +from PIL import Image + +from facexlib.assessment import init_assessment_model +from facexlib.detection import init_detection_model + + +def main(args): + """Scripts about evaluating face quality. + Two steps: + 1) detect the face region and crop the face + 2) evaluate the face quality by hyperIQA + """ + # initialize model + det_net = init_detection_model(args.detection_model_name, half=False) + assess_net = init_assessment_model(args.assess_model_name, half=False) + + # specified face transformation in original hyperIQA + transforms = torchvision.transforms.Compose([ + torchvision.transforms.Resize((512, 384)), + torchvision.transforms.RandomCrop(size=224), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) + ]) + + img = cv2.imread(args.img_path) + img_name = os.path.basename(args.img_path) + basename, _ = os.path.splitext(img_name) + with torch.no_grad(): + bboxes = det_net.detect_faces(img, 0.97) + box = list(map(int, bboxes[0])) + pred_scores = [] + # BRG -> RGB + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + for i in range(10): + detect_face = img[box[1]:box[3], box[0]:box[2], :] + detect_face = Image.fromarray(detect_face) + + detect_face = transforms(detect_face) + detect_face = torch.tensor(detect_face.cuda()).unsqueeze(0) + + pred = assess_net(detect_face) + pred_scores.append(float(pred.item())) + score = np.mean(pred_scores) + # quality score ranges from 0-100, a higher score indicates a better quality + print(f'{basename} {score:.4f}') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--img_path', type=str, default='assets/test2.jpg') + parser.add_argument('--detection_model_name', type=str, default='retinaface_resnet50') + parser.add_argument('--assess_model_name', type=str, default='hypernet') + parser.add_argument('--half', action='store_true') + args = parser.parse_args() + + main(args) diff --git a/inference/inference_matting.py b/inference/inference_matting.py new file mode 100644 index 0000000..5d03398 --- /dev/null +++ 
b/inference/inference_matting.py @@ -0,0 +1,65 @@ +import argparse +import cv2 +import numpy as np +import torch.nn.functional as F +from torchvision.transforms.functional import normalize + +from facexlib.matting import init_matting_model +from facexlib.utils import img2tensor + + +def main(args): + modnet = init_matting_model() + + # read image + img = cv2.imread(args.img_path) / 255. + # unify image channels to 3 + if len(img.shape) == 2: + img = img[:, :, None] + if img.shape[2] == 1: + img = np.repeat(img, 3, axis=2) + elif img.shape[2] == 4: + img = img[:, :, 0:3] + + img_t = img2tensor(img, bgr2rgb=True, float32=True) + normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + img_t = img_t.unsqueeze(0).cuda() + + # resize image for input + _, _, im_h, im_w = img_t.shape + ref_size = 512 + if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size: + if im_w >= im_h: + im_rh = ref_size + im_rw = int(im_w / im_h * ref_size) + elif im_w < im_h: + im_rw = ref_size + im_rh = int(im_h / im_w * ref_size) + else: + im_rh = im_h + im_rw = im_w + im_rw = im_rw - im_rw % 32 + im_rh = im_rh - im_rh % 32 + img_t = F.interpolate(img_t, size=(im_rh, im_rw), mode='area') + + # inference + _, _, matte = modnet(img_t, True) + + # resize and save matte + matte = F.interpolate(matte, size=(im_h, im_w), mode='area') + matte = matte[0][0].data.cpu().numpy() + cv2.imwrite(args.save_path, (matte * 255).astype('uint8')) + + # get foreground + matte = matte[:, :, None] + foreground = img * matte + np.full(img.shape, 1) * (1 - matte) + cv2.imwrite(args.save_path.replace('.png', '_fg.png'), foreground * 255) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--img_path', type=str, default='assets/test.jpg') + parser.add_argument('--save_path', type=str, default='test_matting.png') + args = parser.parse_args() + + main(args) diff --git a/inference/inference_parsing.py b/inference/inference_parsing.py new file mode 100644 index 0000000..b49700a --- /dev/null +++ b/inference/inference_parsing.py @@ -0,0 +1,74 @@ +import argparse +import cv2 +import numpy as np +import os +import torch +from torchvision.transforms.functional import normalize + +from facexlib.parsing import init_parsing_model +from facexlib.utils.misc import img2tensor + + +def vis_parsing_maps(img, parsing_anno, stride, save_anno_path=None, save_vis_path=None): + # Colors for all 20 parts + part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 0, 85], [255, 0, 170], [0, 255, 0], [85, 255, 0], + [170, 255, 0], [0, 255, 85], [0, 255, 170], [0, 0, 255], [85, 0, 255], [170, 0, 255], [0, 85, 255], + [0, 170, 255], [255, 255, 0], [255, 255, 85], [255, 255, 170], [255, 0, 255], [255, 85, 255], + [255, 170, 255], [0, 255, 255], [85, 255, 255], [170, 255, 255]] + # 0: 'background' + # attributions = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', + # 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose', + # 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', + # 16 'cloth', 17 'hair', 18 'hat'] + vis_parsing_anno = parsing_anno.copy().astype(np.uint8) + vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) + if save_anno_path is not None: + cv2.imwrite(save_anno_path, vis_parsing_anno) + + if save_vis_path is not None: + vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 + num_of_class = np.max(vis_parsing_anno) + for pi in range(1, num_of_class + 1): + index = 
np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi] + + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) + vis_im = cv2.addWeighted(img, 0.4, vis_parsing_anno_color, 0.6, 0) + + cv2.imwrite(save_vis_path, vis_im) + + +def main(img_path, output): + net = init_parsing_model(model_name='bisenet') + + img_name = os.path.basename(img_path) + img_basename = os.path.splitext(img_name)[0] + + img_input = cv2.imread(img_path) + img_input = cv2.resize(img_input, (512, 512), interpolation=cv2.INTER_LINEAR) + img = img2tensor(img_input.astype('float32') / 255., bgr2rgb=True, float32=True) + normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True) + img = torch.unsqueeze(img, 0).cuda() + + with torch.no_grad(): + out = net(img)[0] + out = out.squeeze(0).cpu().numpy().argmax(0) + + vis_parsing_maps( + img_input, + out, + stride=1, + save_anno_path=os.path.join(output, f'{img_basename}.png'), + save_vis_path=os.path.join(output, f'{img_basename}_vis.png')) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('--input', type=str, default='datasets/ffhq/ffhq_512/00000000.png') + parser.add_argument('--output', type=str, default='results', help='output folder') + args = parser.parse_args() + + os.makedirs(args.output, exist_ok=True) + + main(args.input, args.output) diff --git a/requirements.txt b/requirements.txt index 24ce15a..fe67f8f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,10 @@ +filterpy +numba numpy +numpy +opencv-python +Pillow +scipy +torch +torchvision +tqdm
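For reference, each new inference script can be exercised with the argparse defaults shown above; assuming the repository's assets/ and datasets/ layout, the invocations look like:

python inference/inference_headpose.py --img_path assets/test.jpg --save_path assets/test_headpose.png
python inference/inference_hyperiqa.py --img_path assets/test2.jpg
python inference/inference_matting.py --img_path assets/test.jpg --save_path test_matting.png
python inference/inference_parsing.py --input datasets/ffhq/ffhq_512/00000000.png --output results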