forked from kuielab/mdx-net-submission
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
25 lines (17 loc) · 772 Bytes
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import soundfile as sf
import torch
from evaluator.music_demixing import MusicDemixingPredictor
from mdxnet import MDXNet
# Torch device passed to MDXNet when the model is built in
# Predictor.prediction_setup; this submission pins it to the CPU.
device = torch.device("cpu")
class Predictor(MusicDemixingPredictor):
    """MDX-Net submission for the music-demixing evaluator.

    The driving loop lives in ``MusicDemixingPredictor`` (defined outside this
    file). It presumably calls ``prediction_setup`` once and then
    ``prediction`` once per track; TODO confirm against
    ``evaluator.music_demixing``.
    """

    def prediction_setup(self):
        """Build the MDX-Net model (``'leaderboard_A'`` config) on ``device``."""
        self.model = MDXNet(device, 'leaderboard_A')

    def prediction(self, mixture_file_path, bass_file_path, drums_file_path,
                   other_file_path, vocals_file_path):
        """Separate one mixture file into four stem files.

        Args:
            mixture_file_path: path of the audio mixture to read.
            bass_file_path, drums_file_path, other_file_path, vocals_file_path:
                destination paths for the stems. Sources returned by
                ``self.model.demix`` are written to these paths in this order
                (assumes demix returns bass, drums, other, vocals; verify
                against mdxnet).

        Raises:
            ValueError: if the model returns a number of sources different
                from the four output paths.
        """
        file_paths = [bass_file_path, drums_file_path, other_file_path, vocals_file_path]

        # soundfile.read returns (data, samplerate); data is (frames, channels),
        # or (frames,) for mono. Keep the rate so the stems can be written with
        # the same timing as the input instead of a hard-coded 44100 Hz.
        mix, samplerate = sf.read(mixture_file_path)

        # Transposed to channels-first before demixing, presumably the layout
        # demix expects. TODO confirm against mdxnet.
        sources = self.model.demix(mix.T)

        # Fail loudly instead of raising a bare IndexError (too many sources)
        # or silently skipping stem files (too few).
        if len(sources) != len(file_paths):
            raise ValueError(
                f"expected {len(file_paths)} sources from the model, got {len(sources)}"
            )

        for path, source in zip(file_paths, sources):
            # Transpose back to soundfile's (frames, channels) layout.
            sf.write(path, source.T, samplerate=samplerate)
# Script entry point: build the submission predictor and run it. ``run`` is not
# defined on Predictor, so it comes from MusicDemixingPredictor (outside this
# file). It presumably drives prediction_setup/prediction per track; TODO confirm.
# NOTE: there is no ``if __name__ == "__main__"`` guard, so this also executes
# whenever the module is imported.
submission = Predictor()
submission.run()
# Printed only after run() returns without raising.
print("Successfully completed music demixing...")