import os
import types

import librosa
import numpy as np
import pytorch_lightning as pl
import torch
from cog import BasePredictor, Input, Path
from torch.utils.data import DataLoader

import htsat_config
from data_processor import MusdbDataset
from models.asp_model import AutoTaggingWarpper, SeparatorModel, ZeroShotASP
from models.htsat import HTSAT_Swin_Transformer
from sed_model import SEDWrapper
from utils import prepprocess_audio


def get_inference_configs():
    """Build the namespace of inference settings used by the tagger and separator."""
    config = types.SimpleNamespace()
    # Pretrained checkpoints: the zero-shot separator and the HTS-AT tagging model.
    config.ckpt_path = "pretrained/zeroshot_asp_full.ckpt"
    config.sed_ckpt_path = "pretrained/htsat_audioset_2048d.ckpt"
    config.wave_output_path = "predict_outputs"
    config.test_key = "query_name"
    config.test_type = "mix"
    config.loss_type = "mae"
    config.infer_type = "mean"
    # Audio front-end settings.
    config.sample_rate = 32000
    config.segment_frames = 200
    config.hop_samples = 320
    config.energy_thres = 0.1
    config.using_whiting = False
    # Query-embedding size and AudioSet class count.
    config.latent_dim = 2048
    config.classes_num = 527
    config.overlap_rate = 0.5
    config.num_workers = 1
    return config
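
# The settings above are plain attributes on a SimpleNamespace, so a caller can
# tweak them before the models are loaded. An illustrative (not prescribed)
# override:
#
#     config = get_inference_configs()
#     config.overlap_rate = 0.75  # denser segment overlap at inference time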


def load_models(config):
    """Instantiate the HTS-AT tagger and the zero-shot separator from checkpoints."""
    # HTS-AT Swin-Transformer backbone used for audio tagging.
    sed_model = HTSAT_Swin_Transformer(
        spec_size=htsat_config.htsat_spec_size,
        patch_size=htsat_config.htsat_patch_size,
        in_chans=1,
        num_classes=htsat_config.classes_num,
        window_size=htsat_config.htsat_window_size,
        config=htsat_config,
        depths=htsat_config.htsat_depth,
        embed_dim=htsat_config.htsat_dim,
        patch_stride=htsat_config.htsat_stride,
        num_heads=htsat_config.htsat_num_head,
    )
    at_model = SEDWrapper(sed_model=sed_model, config=htsat_config, dataset=None)
    ckpt = torch.load(config.sed_ckpt_path, map_location="cpu")
    at_model.load_state_dict(ckpt["state_dict"])

    # Wrapper that averages the tagger's predictions over a query clip.
    at_wrapper = AutoTaggingWarpper(
        at_model=at_model, config=config, target_keys=[config.test_key]
    )

    # Zero-shot separator conditioned on the tagging embedding.
    asp_model = ZeroShotASP(channels=1, config=config, at_model=at_model, dataset=None)
    ckpt = torch.load(config.ckpt_path, map_location="cpu")
    asp_model.load_state_dict(ckpt["state_dict"], strict=False)
    return at_wrapper, asp_model
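
# A minimal sketch of exercising the loader on its own (assumes the pretrained
# checkpoints listed in get_inference_configs() exist on disk):
#
#     config = get_inference_configs()
#     at_wrapper, asp_model = load_models(config)
#     print(type(at_wrapper).__name__, type(asp_model).__name__)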


def get_dataloader_from_sound_file(sound_file_path, config):
    """Load an audio file and wrap it in a single-track DataLoader."""
    # Load at the file's native rate; prepprocess_audio then resamples to
    # config.sample_rate and downmixes according to config.test_type.
    signal, sampling_rate = librosa.load(str(sound_file_path), sr=None)
    signal = prepprocess_audio(
        signal[:, None], sampling_rate, config.sample_rate, config.test_type
    )
    signal = np.array([signal, signal])  # Duplicate signal for later use
    dataset = MusdbDataset(tracks=[signal])
    data_loader = DataLoader(
        dataset, num_workers=config.num_workers, batch_size=1, shuffle=False
    )
    return data_loader
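
# Illustrative array shapes for a hypothetical 10 s mono file at 44.1 kHz
# (an assumed example, not data from the repository):
#
#   librosa.load(..., sr=None)  -> signal: (441000,)
#   signal[:, None]             -> (441000, 1)
#   prepprocess_audio(...)      -> (320000,)  resampled to 32 kHz
#   np.array([signal, signal])  -> (2, 320000)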


class Predictor(BasePredictor):
    def setup(self):
        """Prepare the config, output directory, and pretrained models once per boot."""
        self.config = get_inference_configs()
        os.makedirs(self.config.wave_output_path, exist_ok=True)
        self.at_wrapper, self.asp_model = load_models(self.config)

    def predict(
        self,
        mix_file: Path = Input(description="Reference sound to extract the source from"),
        query_file: Path = Input(description="Query sound to be searched for and extracted from the mix"),
    ) -> Path:
        """Separate the source matching query_file out of mix_file."""
        ref_loader = get_dataloader_from_sound_file(str(mix_file), self.config)
        query_loader = get_dataloader_from_sound_file(str(query_file), self.config)

        # First pass: average the tagger's predictions over the query clip
        # (requires a GPU, per gpus=1).
        trainer = pl.Trainer(gpus=1)
        trainer.test(self.at_wrapper, test_dataloaders=query_loader)
        avg_at = self.at_wrapper.avg_at

        # Second pass: run the separator on the mixture, conditioned on avg_at,
        # writing the separated waveform to wave_output_path.
        exp_model = SeparatorModel(
            model=self.asp_model,
            config=self.config,
            target_keys=[self.config.test_key],
            avg_at=avg_at,
            using_wiener=False,
            calc_sdr=False,
            output_wav=True,
        )
        trainer.test(exp_model, test_dataloaders=ref_loader)

        prediction_path = os.path.join(
            self.config.wave_output_path, f"0_{self.config.test_key}_pred_(0.0).wav"
        )
        return Path(prediction_path)
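

# Minimal local smoke test, bypassing the Cog HTTP server. A sketch only: the
# example file names below are hypothetical placeholders, and a CUDA GPU is
# assumed because predict() builds pl.Trainer(gpus=1).
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    out = predictor.predict(
        mix_file=Path("examples/mixture.wav"),  # hypothetical mixture path
        query_file=Path("examples/query.wav"),  # hypothetical query path
    )
    print(f"Separated source written to: {out}")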