-
Notifications
You must be signed in to change notification settings - Fork 133
Expand file tree
/
Copy pathpredict.py
More file actions
120 lines (101 loc) · 3.77 KB
/
predict.py
File metadata and controls
120 lines (101 loc) · 3.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import json
import os
import subprocess
import tempfile
import random
import numpy as np
import torch
import click
from cog import BasePredictor, Input, Path
import music21
from midi2audio import FluidSynth
from DatasetManager.chorale_dataset import ChoraleDataset
from DatasetManager.dataset_manager import DatasetManager
from DatasetManager.metadata import FermataMetadata, KeyMetadata, TickMetadata
from DeepBach.model_manager import DeepBach
class Predictor(BasePredictor):
    """Cog predictor that generates four-voice Bach-style chorales with DeepBach.

    ``setup`` loads the trained model once; ``predict`` runs pseudo-Gibbs
    sampling and returns either a MIDI file or an MP3 rendering.
    """

    def setup(self):
        """Load the DeepBach model and the MIDI-to-audio synthesizer.

        The hyperparameters below must match those of the checkpoint
        restored by ``self.deepbach.load()`` — TODO confirm against the
        training configuration.
        """
        note_embedding_dim = 20
        meta_embedding_dim = 20
        num_layers = 2
        lstm_hidden_size = 256
        dropout_lstm = 0.5
        linear_hidden_size = 256

        # Bach chorales dataset with fermata/tick/key metadata at a
        # subdivision of 4 ticks per quarter note.
        dataset_manager = DatasetManager()
        metadatas = [FermataMetadata(), TickMetadata(subdivision=4), KeyMetadata()]
        chorale_dataset_kwargs = {
            "voice_ids": [0, 1, 2, 3],
            "metadatas": metadatas,
            "sequences_size": 8,
            "subdivision": 4,
        }
        bach_chorales_dataset: ChoraleDataset = dataset_manager.get_dataset(
            name="bach_chorales", **chorale_dataset_kwargs
        )

        self.deepbach = DeepBach(
            dataset=bach_chorales_dataset,
            note_embedding_dim=note_embedding_dim,
            meta_embedding_dim=meta_embedding_dim,
            num_layers=num_layers,
            lstm_hidden_size=lstm_hidden_size,
            dropout_lstm=dropout_lstm,
            linear_hidden_size=linear_hidden_size,
        )
        self.deepbach.load()

        # FluidSynth performs the MIDI -> audio conversion in predict().
        self.fs = FluidSynth()

    def predict(
        self,
        num_iterations: int = Input(
            default=500,
            description="Number of parallel pseudo-Gibbs sampling iterations",
        ),
        sequence_length_ticks: int = Input(
            default=64, ge=16, description="Length of the generated chorale (in ticks)"
        ),
        output_type: str = Input(
            default="audio",
            choices=["midi", "audio"],
            description="Output representation type: can be audio or midi",
        ),
        seed: int = Input(default=-1, description="Random seed, -1 for random"),
    ) -> Path:
        """Generate a chorale and return it as an MP3 or MIDI file.

        Raises:
            ValueError: if ``output_type`` is neither "audio" nor "midi"
                (should be unreachable given the Input choices).
        """
        if seed >= 0:
            # Seed every RNG the generation path may touch so runs are
            # reproducible for a given seed.
            random.seed(seed)
            np.random.seed(seed)
            torch.use_deterministic_algorithms(True)
            torch.manual_seed(seed)

        score, tensor_chorale, tensor_metadata = self.deepbach.generation(
            num_iterations=num_iterations,
            sequence_length_ticks=sequence_length_ticks,
        )

        if output_type == "audio":
            output_path_wav = Path(tempfile.mkdtemp()) / "output.wav"
            output_path_mp3 = Path(tempfile.mkdtemp()) / "output.mp3"
            # Render the score to MIDI, synthesize to WAV, then transcode
            # to MP3 while trimming leading/trailing silence.
            midi_score = score.write("midi")
            self.fs.midi_to_audio(midi_score, str(output_path_wav))
            subprocess.check_output(
                [
                    "ffmpeg",
                    "-i",
                    str(output_path_wav),
                    "-af",
                    "silenceremove=1:0:-50dB,aformat=dblp,areverse,silenceremove=1:0:-50dB,aformat=dblp,areverse",  # strip silence
                    str(output_path_mp3),
                ],
            )
            return output_path_mp3
        elif output_type == "midi":
            output_path_midi = Path(tempfile.mkdtemp()) / "output.mid"
            score.write("midi", fp=output_path_midi)
            return output_path_midi
        # Defensive: Input(choices=...) should prevent this, but fail
        # loudly rather than silently returning None.
        raise ValueError(f"Unsupported output_type: {output_type!r}")