-
Notifications
You must be signed in to change notification settings - Fork 0
/
apply_esn.py
60 lines (45 loc) · 1.88 KB
/
apply_esn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import pickle

import numpy as np
import pandas as pd

from data_io.midi_file import MODULATION, TEMPO, midi_tones_to_midi_file
from data_io.model_data import convert_raw_to_training_data, load_data_raw
from esn import ESN
from postprocessing.postprocessing import PostProcessorMC
# When True, apply_esn builds a PostProcessorMC and passes it to
# model.predict_sequence; when False, None is passed instead.
POST_PROCESSING = True
# Pickled, already-trained ESN model that apply_esn loads and predicts with.
FILE = 'models/esn/esn_tuned.pkl'
MEASURE_LEN = 16  # Length of a measure in symbols (PostProcessorMC measure_length)

# ESN hyperparameters. They only matter when a new model is trained instead of
# loading FILE; each maps to the ESN constructor keyword noted alongside it.
RESERVOIR = 2000  # reservoir_size
W_IN = 0.3  # W_in_scaling
BIAS = 0.9  # bias_scaling
SP = 1.0  # spectral_radius
LEAKING = 0.1  # leaking_rate
WASHOUT_TIME = 100  # washout_time
RIDGE_PARAM = 1  # ridge_param
def _ensure_parent_dir(path):
    """Create the parent directory of *path* if it does not already exist."""
    parent = os.path.dirname(path)
    if parent:  # a bare filename has no directory component to create
        os.makedirs(parent, exist_ok=True)


def _load_model(path):
    """Unpickle and return the model stored at *path*."""
    # NOTE: pickle.load can execute arbitrary code embedded in the file, so
    # only point this at model files you produced yourself or otherwise trust.
    with open(path, 'rb') as inp:
        print('Loading model...')
        return pickle.load(inp)


def _train_model(u, y, ive, ove):
    """Fit and return a new ESN built from the module-level hyperparameters."""
    model = ESN(u.shape[1], y.shape[1], reservoir_size=RESERVOIR,
                W_in_scaling=W_IN, bias_scaling=BIAS, spectral_radius=SP,
                leaking_rate=LEAKING, ridge_param=RIDGE_PARAM,
                ive=ive, ove=ove, washout_time=WASHOUT_TIME, silent=False)
    model.fit(u, y)
    return model


def _write_outputs(predicted_sequence, midi_raw):
    """Export the prediction as a tab-separated table and as two MIDI files."""
    csv_path = "analyseData/esn.txt"
    _ensure_parent_dir(csv_path)
    pd.DataFrame(predicted_sequence).to_csv(csv_path, header=False, index=False, sep='\t')

    pred_path = "output_midi_files/pred_esn_mc.mid"
    _ensure_parent_dir(pred_path)
    midi_tones_to_midi_file(predicted_sequence, pred_path, tempo=TEMPO, modulation=MODULATION)

    # np.concatenate returns a freshly allocated array, so midi_raw is left
    # untouched without needing an explicit copy.
    full_path = "output_midi_files/full_seq_plus_pred_esn_mc.mid"
    song = np.concatenate((midi_raw, predicted_sequence))
    midi_tones_to_midi_file(song, full_path, tempo=TEMPO, modulation=MODULATION)


def apply_esn(input_file='F.txt', n_drive=300, n_predict=486, train=False):
    """Continue a MIDI piece with an echo state network and export the result.

    Loads the raw MIDI data from *input_file* and drops its last 16 rows. It
    converts that data to training pairs with ``window_length=1``. It then
    obtains a model, either unpickled from FILE or freshly trained when
    *train* is true.

    The last *n_drive* input/target steps are passed to
    ``model.predict_sequence``, which generates *n_predict* steps. The
    prediction is written to ``analyseData/esn.txt`` (tab-separated, no
    header/index). It is also written to ``output_midi_files/`` twice: once on
    its own and once appended to the input piece. Missing output directories
    are created.

    Args:
        input_file: data file passed to ``load_data_raw``.
        n_drive: number of trailing steps of ``u``/``y`` handed to
            ``predict_sequence`` (presumably to warm up the reservoir before
            free-running -- confirm against ESN.predict_sequence).
        n_predict: number of steps to predict.
        train: if True, fit a new ESN with the module-level hyperparameters
            instead of loading the pickled model from FILE.

    Returns:
        The sequence returned by ``model.predict_sequence``.

    Raises:
        ValueError: if *n_drive* is not positive.
    """
    if n_drive <= 0:
        # u[-0:] would silently select the *whole* sequence, not zero steps.
        raise ValueError(f"n_drive must be positive, got {n_drive}")

    # Drops the last 16 rows of the raw data; presumably a held-out measure
    # (cf. MEASURE_LEN) -- TODO confirm.
    midi_raw = load_data_raw(input_file)[:-16, :]
    data, ove, ive = convert_raw_to_training_data(
        midi_raw, window_length=1, flatten_output=True)
    u, y = data

    model = _train_model(u, y, ive, ove) if train else _load_model(FILE)

    post_processor = (
        PostProcessorMC(ove, midi_raw, measure_length=MEASURE_LEN)
        if POST_PROCESSING else None
    )

    predicted_sequence = model.predict_sequence(
        u[-n_drive:, :], y[-n_drive:, :], n_predict, post_processor)

    _write_outputs(predicted_sequence, midi_raw)
    print('Done!')
    return predicted_sequence
if __name__ == '__main__':
    # Only run the generation pipeline when executed as a script, not on import.
    apply_esn()