diff --git a/audio.py b/audio.py index 32b20c449..32ab5fabe 100644 --- a/audio.py +++ b/audio.py @@ -97,7 +97,7 @@ def _linear_to_mel(spectogram): def _build_mel_basis(): assert hp.fmax <= hp.sample_rate // 2 - return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, + return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin, fmax=hp.fmax) def _amp_to_db(x): diff --git a/inference.py b/inference.py index 90692521e..7fe0ac53e 100644 --- a/inference.py +++ b/inference.py @@ -159,21 +159,29 @@ def datagen(frames, mels): def _load(checkpoint_path): if device == 'cuda': - checkpoint = torch.load(checkpoint_path) + checkpoint = torch.load(checkpoint_path, weights_only=False) else: checkpoint = torch.load(checkpoint_path, - map_location=lambda storage, loc: storage) + map_location='cpu', + weights_only=False) return checkpoint def load_model(path): - model = Wav2Lip() print("Load checkpoint from: {}".format(path)) checkpoint = _load(path) - s = checkpoint["state_dict"] - new_s = {} - for k, v in s.items(): - new_s[k.replace('module.', '')] = v - model.load_state_dict(new_s) + + # Check if it's a TorchScript model + if isinstance(checkpoint, torch.jit.ScriptModule): + print("Detected TorchScript model, loading directly...") + model = checkpoint + else: + # Regular checkpoint with state_dict + model = Wav2Lip() + s = checkpoint["state_dict"] + new_s = {} + for k, v in s.items(): + new_s[k.replace('module.', '')] = v + model.load_state_dict(new_s) model = model.to(device) return model.eval() @@ -252,6 +260,9 @@ def main(): model = load_model(args.checkpoint_path) print ("Model loaded") + # Ensure temp directory exists + os.makedirs('temp', exist_ok=True) + frame_h, frame_w = full_frames[0].shape[:-1] out = cv2.VideoWriter('temp/result.avi', cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))