1 change: 1 addition & 0 deletions requirements.txt
@@ -6,3 +6,4 @@ torch==1.1.0
torchvision==0.3.0
tqdm==4.45.0
numba==0.48
flask
309 changes: 309 additions & 0 deletions server.py
@@ -0,0 +1,309 @@
from flask import Flask, request, jsonify

from os import listdir, path
import numpy as np
import scipy, cv2, os, sys, argparse, audio
import json, subprocess, random, string
from tqdm import tqdm
from glob import glob
import torch, face_detection
from models import Wav2Lip
import platform
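
# audio, models, and face_detection are local modules from the Wav2Lip repo
# (not pip packages); run this server from the repository root so they import.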


app = Flask(__name__)
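
# server.py exposes Wav2Lip inference over HTTP: the CLI flags from
# inference.py become fields on a module-level Args object that each request
# overwrites before calling main().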


class Args:
def __init__(self):
self.checkpoint_path = "checkpoints/wav2lip.pth"
self.face = ""
self.audio = "audio"
self.outfile = 'results/result_voice.mp4'
self.static = False
self.fps = 25.
self.pads = [0, 10, 0, 0]
self.face_det_batch_size = 16
self.wav2lip_batch_size = 128
self.resize_factor = 1
self.crop = [0, -1, 0, -1]
self.box = [-1, -1, -1, -1]
self.rotate = False
self.nosmooth = False
self.img_size = 96

args = Args()
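
# NOTE: a single shared Args object and model mean concurrent /synthesize
# requests would interfere with each other; keep requests sequential (or add
# a lock) if the server is run with a threaded or multi-process WSGI setup.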

model = None  # assigned by load_model() before the server starts

if os.path.isfile(args.face) and args.face.split('.')[-1] in ['jpg', 'png', 'jpeg']:
    args.static = True

def get_smoothened_boxes(boxes, T):
    # Average each box with its neighbours over a window of T frames to reduce jitter.
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i: i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes


def face_detect(images):
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)

    batch_size = args.face_det_batch_size

    while True:
        predictions = []
        try:
            for i in tqdm(range(0, len(images), batch_size)):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError(
                    'Image too big to run face detection on GPU. Please lower resize_factor in Args.')
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = args.pads
    # Fallback box, reused for frames where no face is detected.
    fnr = 0
    x1, y1, x2, y2 = 0, 0, 100, 100
    for rect, image in zip(predictions, images):
        if rect is None:
            # Instead of aborting, save the frame for inspection and reuse the last good box.
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            print(f'Face not detected in frame {fnr}! Ensure the video contains a face in all the frames.')
        else:
            y1 = max(0, rect[1] - pady1)
            y2 = min(image.shape[0], rect[3] + pady2)
            x1 = max(0, rect[0] - padx1)
            x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])
        fnr += 1

    boxes = np.array(results)
    if not args.nosmooth:
        boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    return results


def datagen(frames, mels):
    # Pair each mel chunk with a face crop and yield model-ready batches.
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if args.box[0] == -1:
        if not args.static:
            face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection
        else:
            face_det_results = face_detect([frames[0]])
    else:
        print('Using the specified bounding box instead of face detection...')
        y1, y2, x1, x2 = args.box
        face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]

    for i, m in enumerate(mels):
        idx = 0 if args.static else i % len(frames)
        frame_to_save = frames[idx].copy()
        face, coords = face_det_results[idx].copy()

        face = cv2.resize(face, (args.img_size, args.img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= args.wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            # Mask the lower half of each face; the model inpaints the mouth region,
            # conditioned on the full (unmasked) crop concatenated along the channel axis.
            img_masked = img_batch.copy()
            img_masked[:, args.img_size // 2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, args.img_size // 2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch


mel_step_size = 16  # each chunk covers 16 mel frames (~0.2 s of audio at 80 mel frames/sec)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))


def _load(checkpoint_path):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        # Remap CUDA tensors onto the CPU when no GPU is available.
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint


def load_model(path):
    global model
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    # Strip the 'module.' prefix that nn.DataParallel adds, so the keys match a bare Wav2Lip model.
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)

    model = model.to(device)
    return model.eval()


def main():
    # Full pipeline: read frames -> mel spectrogram -> Wav2Lip batches -> rendered video.
    if not os.path.isfile(args.face):
        raise ValueError("'face' must be a valid path to a video/image file")

    elif args.face.split('.')[-1] in ['jpg', 'png', 'jpeg']:
        full_frames = [cv2.imread(args.face)]
        fps = args.fps

    else:
        video_stream = cv2.VideoCapture(args.face)
        fps = video_stream.get(cv2.CAP_PROP_FPS)

        print('Reading video frames...')

        full_frames = []
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break
            if args.resize_factor > 1:
                frame = cv2.resize(frame, (frame.shape[1] // args.resize_factor, frame.shape[0] // args.resize_factor))

            if args.rotate:
                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)

            y1, y2, x1, x2 = args.crop
            if x2 == -1: x2 = frame.shape[1]
            if y2 == -1: y2 = frame.shape[0]

            frame = frame[y1:y2, x1:x2]

            full_frames.append(frame)

    print("Number of frames available for inference: " + str(len(full_frames)))

    if not args.audio.endswith('.wav'):
        print('Extracting raw audio...')
        command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')

        subprocess.call(command, shell=True)
        args.audio = 'temp/temp.wav'

    wav = audio.load_wav(args.audio, 16000)
    mel = audio.melspectrogram(wav)
    print('Mel spectrogram shape:', mel.shape)

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

    # Slice the spectrogram into one mel_step_size-frame window per video frame;
    # the mel features advance at 80 frames per second of audio.
    mel_chunks = []
    mel_idx_multiplier = 80. / fps
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > len(mel[0]):
            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
        i += 1

    print("Length of mel chunks: {}".format(len(mel_chunks)))

    full_frames = full_frames[:len(mel_chunks)]

    batch_size = args.wav2lip_batch_size
    gen = datagen(full_frames.copy(), mel_chunks)

    for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
            total=int(np.ceil(float(len(mel_chunks)) / batch_size)))):
        if i == 0:
            frame_h, frame_w = full_frames[0].shape[:-1]
            out = cv2.VideoWriter('temp/result.avi',
                                  cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))

        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

        with torch.no_grad():
            pred = model(mel_batch, img_batch)

        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

        # Paste each predicted mouth region back into its source frame.
        for p, f, c in zip(pred, frames, coords):
            y1, y2, x1, x2 = c
            p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))

            f[y1:y2, x1:x2] = p
            out.write(f)

    out.release()

    # Mux the silent result video with the input audio into the final output file.
    command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
    subprocess.call(command, shell=platform.system() != 'Windows')


@app.route('/synthesize', methods=['POST'])
def synthesize():
    global model
    try:
        req = request.get_json(force=True)  # force=True: parse JSON regardless of mimetype
        # Reload the model only when the request names a different checkpoint;
        # omitting 'checkpoint_path' keeps the currently loaded model.
        checkpoint_path = req.get('checkpoint_path', args.checkpoint_path)
        if args.checkpoint_path != checkpoint_path:
            args.checkpoint_path = checkpoint_path
            model = load_model(args.checkpoint_path)
        args.face = req.get('face', '')
        args.audio = req.get('audio', '')
        args.outfile = req.get('outfile', 'results/result_voice.mp4')
        # Treat still-image inputs as a static face (re-checked per request).
        args.static = os.path.isfile(args.face) and args.face.split('.')[-1] in ['jpg', 'png', 'jpeg']

        # Run the full inference pipeline with the updated arguments.
        main()
        response = {"status": "success"}
    except Exception as e:
        response = {"status": "failed", "error": str(e)}

    resp = jsonify(response)
    resp.status_code = 200 if response["status"] == "success" else 500
    return resp
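
# Example request (a sketch; the file paths are placeholders for media that
# must already exist on the server's filesystem):
#
#   curl -X POST http://localhost:1206/synthesize \
#        -H 'Content-Type: application/json' \
#        -d '{"checkpoint_path": "checkpoints/wav2lip.pth",
#             "face": "sample_data/input_vid.mp4",
#             "audio": "sample_data/input_audio.wav",
#             "outfile": "results/result_voice.mp4"}'
#
# The response is JSON: {"status": "success"} or {"status": "failed", "error": "..."}.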

if __name__ == '__main__':
    # Load the default checkpoint once at startup so the first request is fast.
    model = load_model(args.checkpoint_path)
    app.run(host='0.0.0.0', port=1206)